Repository: kafka
Updated Branches:
  refs/heads/0.10.1 f91d95ac9 -> ecb51680a


KAFKA-4311: Multi-layer cache eviction causes forwarding to incorrect ProcessorNode

Consider a topology like the one below. If a record arriving in `tableOne` causes
a cache eviction, it will trigger the `leftJoin`, which does a `get` on
`reducer-store`. If the key is not currently cached in `reducer-store` but is
in the backing store, it will be put into the cache, and that put may also trigger an
eviction. If it does trigger an eviction and the eldest entry is dirty, the cache will
flush the dirty keys. It is at this point that a ClassCastException is thrown.
This occurs because the ProcessorContext's current node is still the `leftJoin`
node, so the records flushed from `reducer-store` are forwarded to the `leftJoin`'s
child, `mapValues`, instead of to the children of the node that owns `reducer-store`.
We need to set the correct `ProcessorNode` on the context in the
`ForwardingCacheFlushListener` before calling `context.forward`. We also need
to reset the `ProcessorNode` to the previous node once
`context.forward` has completed.

```
final KTable<String, String> one = builder.table(Serdes.String(), Serdes.String(), tableOne, tableOne);
final KTable<Long, String> two = builder.table(Serdes.Long(), Serdes.String(), tableTwo, tableTwo);
final KTable<String, Long> reduce = two.groupBy(new KeyValueMapper<Long, String, KeyValue<String, Long>>() {
    @Override
    public KeyValue<String, Long> apply(final Long key, final String value) {
        return new KeyValue<>(value, key);
    }
}, Serdes.String(), Serdes.Long())
        .reduce(new Reducer<Long>() {..}, new Reducer<Long>() {..}, "reducer-store");

one.leftJoin(reduce, new ValueJoiner<String, Long, String>() {..})
    .mapValues(new ValueMapper<String, String>() {..});

```

Author: Damian Guy <[email protected]>

Reviewers: Eno Thereska, Guozhang Wang

Closes #2051 from dguy/kafka-4311


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/c1fb615a
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/c1fb615a
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/c1fb615a

Branch: refs/heads/0.10.1
Commit: c1fb615a6ef6cbd0a725bd70a2a602ca31402f8a
Parents: f91d95a
Author: Damian Guy <[email protected]>
Authored: Wed Nov 9 10:43:27 2016 -0800
Committer: Guozhang Wang <[email protected]>
Committed: Wed Nov 23 07:52:57 2016 -0800

----------------------------------------------------------------------
 .../internals/ForwardingCacheFlushListener.java | 22 ++++--
 .../internals/InternalProcessorContext.java     |  1 +
 .../internals/ProcessorContextImpl.java         | 13 ++--
 .../processor/internals/StandbyContextImpl.java |  5 ++
 .../streams/processor/internals/StreamTask.java | 28 +++-----
 .../streams/state/internals/NamedCache.java     |  8 ++-
 .../kstream/internals/KTableAggregateTest.java  | 72 ++++++++++++++++++++
 .../internals/ProcessorTopologyTest.java        |  1 +
 .../streams/state/KeyValueStoreTestDriver.java  |  7 ++
 .../streams/state/internals/NamedCacheTest.java |  6 ++
 .../apache/kafka/test/KStreamTestDriver.java    | 23 ++++++-
 .../apache/kafka/test/MockProcessorContext.java |  7 ++
 12 files changed, 160 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
 
b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
index 1796be9..4635fc9 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ForwardingCacheFlushListener.java
@@ -17,22 +17,32 @@
 package org.apache.kafka.streams.kstream.internals;
 
 import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
+import org.apache.kafka.streams.processor.internals.ProcessorNode;
 
 class ForwardingCacheFlushListener<K, V> implements CacheFlushListener<K, V> {
-    private final ProcessorContext context;
+    private final InternalProcessorContext context;
     private final boolean sendOldValues;
+    private final ProcessorNode myNode;
 
     ForwardingCacheFlushListener(final ProcessorContext context, final boolean 
sendOldValues) {
-        this.context = context;
+        this.context = (InternalProcessorContext) context;
+        myNode = this.context.currentNode();
         this.sendOldValues = sendOldValues;
     }
 
     @Override
     public void apply(final K key, final V newValue, final V oldValue) {
-        if (sendOldValues) {
-            context.forward(key, new Change<>(newValue, oldValue));
-        } else {
-            context.forward(key, new Change<>(newValue, null));
+        final ProcessorNode prev = context.currentNode();
+        context.setCurrentNode(myNode);
+        try {
+            if (sendOldValues) {
+                context.forward(key, new Change<>(newValue, oldValue));
+            } else {
+                context.forward(key, new Change<>(newValue, null));
+            }
+        } finally {
+            context.setCurrentNode(prev);
         }
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
index 251ff3f..016964b 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
@@ -42,6 +42,7 @@ public interface InternalProcessorContext extends 
ProcessorContext {
      */
     void setCurrentNode(ProcessorNode currentNode);
 
+    ProcessorNode currentNode();
     /**
      * Get the thread-global cache
      */

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
index 195e5a4..be18593 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
@@ -128,13 +128,11 @@ public class ProcessorContextImpl implements 
InternalProcessorContext, RecordCol
      */
     @Override
     public StateStore getStateStore(String name) {
-        ProcessorNode node = task.node();
-
-        if (node == null)
+        if (currentNode == null)
             throw new TopologyBuilderException("Accessing from an unknown 
node");
 
-        if (!node.stateStores.contains(name)) {
-            throw new TopologyBuilderException("Processor " + node.name() + " 
has no access to StateStore " + name);
+        if (!currentNode.stateStores.contains(name)) {
+            throw new TopologyBuilderException("Processor " + 
currentNode.name() + " has no access to StateStore " + name);
         }
 
         return stateMgr.getStore(name);
@@ -272,4 +270,9 @@ public class ProcessorContextImpl implements 
InternalProcessorContext, RecordCol
     public void setCurrentNode(final ProcessorNode currentNode) {
         this.currentNode = currentNode;
     }
+
+    @Override
+    public ProcessorNode currentNode() {
+        return currentNode;
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
index 563dbce..80c0026 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
@@ -222,4 +222,9 @@ public class StandbyContextImpl implements 
InternalProcessorContext, RecordColle
     public void setCurrentNode(final ProcessorNode currentNode) {
         // no-op. can't throw as this is called on commit when the StateStores 
get flushed.
     }
+
+    @Override
+    public ProcessorNode currentNode() {
+        throw new UnsupportedOperationException("this should not happen: 
currentNode not supported in standby tasks.");
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index b993054..9a2f03e 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -59,7 +59,6 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
 
     private boolean commitRequested = false;
     private boolean commitOffsetNeeded = false;
-    private ProcessorNode currNode = null;
 
     private boolean requiresPoll = true;
 
@@ -122,11 +121,11 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
         // initialize the task by initializing all its processor nodes in the 
topology
         log.info("{} Initializing processor nodes of the topology", logPrefix);
         for (ProcessorNode node : this.topology.processors()) {
-            this.currNode = node;
+            processorContext.setCurrentNode(node);
             try {
                 node.init(this.processorContext);
             } finally {
-                this.currNode = null;
+                processorContext.setCurrentNode(null);
             }
         }
 
@@ -172,13 +171,13 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
 
         try {
             // process the record by passing to the source node of the topology
-            this.currNode = recordInfo.node();
+            final ProcessorNode currNode = recordInfo.node();
             TopicPartition partition = recordInfo.partition();
 
             log.trace("{} Start processing one record [{}]", logPrefix, 
record);
             final ProcessorRecordContext recordContext = 
createRecordContext(record);
             updateProcessorContext(recordContext, currNode);
-            this.currNode.process(record.key(), record.value());
+            currNode.process(record.key(), record.value());
 
             log.trace("{} Completed processing one record [{}]", logPrefix, 
record);
 
@@ -199,14 +198,13 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
         } catch (KafkaException ke) {
             throw new StreamsException(format("Exception caught in process. 
taskId=%s, processor=%s, topic=%s, partition=%d, offset=%d",
                                               id.toString(),
-                                              currNode.name(),
+                                              
processorContext.currentNode().name(),
                                               record.topic(),
                                               record.partition(),
                                               record.offset()
                                               ), ke);
         } finally {
             processorContext.setCurrentNode(null);
-            this.currNode = null;
         }
 
         return partitionGroup.numBuffered();
@@ -241,10 +239,9 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
      */
     @Override
     public void punctuate(ProcessorNode node, long timestamp) {
-        if (currNode != null)
+        if (processorContext.currentNode() != null)
             throw new IllegalStateException(String.format("%s Current node is 
not null", logPrefix));
 
-        currNode = node;
         final StampedRecord stampedRecord = new StampedRecord(DUMMY_RECORD, 
timestamp);
         updateProcessorContext(createRecordContext(stampedRecord), node);
 
@@ -256,15 +253,10 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
             throw new StreamsException(String.format("Exception caught in 
punctuate. taskId=%s processor=%s", id,  node.name()), ke);
         } finally {
             processorContext.setCurrentNode(null);
-            currNode = null;
         }
     }
 
 
-    public ProcessorNode node() {
-        return this.currNode;
-    }
-
     /**
      * Commit the current task state
      */
@@ -322,10 +314,10 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
      * @throws IllegalStateException if the current node is not null
      */
     public void schedule(long interval) {
-        if (currNode == null)
+        if (processorContext.currentNode() == null)
             throw new IllegalStateException(String.format("%s Current node is 
null", logPrefix));
 
-        punctuationQueue.schedule(new PunctuationSchedule(currNode, interval));
+        punctuationQueue.schedule(new 
PunctuationSchedule(processorContext.currentNode(), interval));
     }
 
     /**
@@ -342,13 +334,13 @@ public class StreamTask extends AbstractTask implements 
Punctuator {
         // make sure close() is called for each node even when there is a 
RuntimeException
         RuntimeException exception = null;
         for (ProcessorNode node : this.topology.processors()) {
-            currNode = node;
+            processorContext.setCurrentNode(node);
             try {
                 node.close();
             } catch (RuntimeException e) {
                 exception = e;
             } finally {
-                currNode = null;
+                processorContext.setCurrentNode(null);
             }
         }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
----------------------------------------------------------------------
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
index 65a836e..ab771df 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java
@@ -109,7 +109,7 @@ class NamedCache {
         for (Bytes key : dirtyKeys) {
             final LRUNode node = getInternal(key);
             if (node == null) {
-                throw new IllegalStateException("Key found in dirty key set, 
but entry is null");
+                throw new IllegalStateException("Key = " + key + " found in 
dirty key set, but entry is null");
             }
             entries.add(new ThreadCache.DirtyEntry(key, node.entry.value, 
node.entry));
             node.entry.markClean();
@@ -120,6 +120,12 @@ class NamedCache {
 
 
     synchronized void put(final Bytes key, final LRUCacheEntry value) {
+        if (!value.isDirty && dirtyKeys.contains(key)) {
+            throw new IllegalStateException(String.format("Attempting to put a 
clean entry for key [%s] " +
+                                                                  "into 
NamedCache [%s] when it already contains " +
+                                                                  "a dirty 
entry for the same key",
+                                                          key, name));
+        }
         LRUNode node = cache.get(key);
         if (node != null) {
             numOverwrites++;

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
index ba33d5c..8378a79 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java
@@ -22,10 +22,14 @@ import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.kstream.Aggregator;
+import org.apache.kafka.streams.kstream.ForeachAction;
 import org.apache.kafka.streams.kstream.Initializer;
 import org.apache.kafka.streams.kstream.KStreamBuilder;
 import org.apache.kafka.streams.kstream.KTable;
 import org.apache.kafka.streams.kstream.KeyValueMapper;
+import org.apache.kafka.streams.kstream.Reducer;
+import org.apache.kafka.streams.kstream.ValueJoiner;
+import org.apache.kafka.streams.kstream.ValueMapper;
 import org.apache.kafka.test.KStreamTestDriver;
 import org.apache.kafka.test.MockAggregator;
 import org.apache.kafka.test.MockInitializer;
@@ -39,6 +43,8 @@ import org.junit.Test;
 
 import java.io.File;
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.junit.Assert.assertEquals;
 
@@ -320,4 +326,70 @@ public class KTableAggregateTest {
                  "1:2"
                  ), proc.processed);
     }
+
+    @Test
+    public void shouldForwardToCorrectProcessorNodeWhenMultiCacheEvictions() 
throws Exception {
+        final String tableOne = "tableOne";
+        final String tableTwo = "tableTwo";
+        final KStreamBuilder builder = new KStreamBuilder();
+        final String reduceTopic = "TestDriver-reducer-store-repartition";
+        final Map<String, Long> reduceResults = new HashMap<>();
+
+        final KTable<String, String> one = builder.table(Serdes.String(), 
Serdes.String(), tableOne, tableOne);
+        final KTable<Long, String> two = builder.table(Serdes.Long(), 
Serdes.String(), tableTwo, tableTwo);
+
+
+        final KTable<String, Long> reduce = two.groupBy(new 
KeyValueMapper<Long, String, KeyValue<String, Long>>() {
+            @Override
+            public KeyValue<String, Long> apply(final Long key, final String 
value) {
+                return new KeyValue<>(value, key);
+            }
+        }, Serdes.String(), Serdes.Long())
+                .reduce(new Reducer<Long>() {
+                    @Override
+                    public Long apply(final Long value1, final Long value2) {
+                        return value1 + value2;
+                    }
+                }, new Reducer<Long>() {
+                    @Override
+                    public Long apply(final Long value1, final Long value2) {
+                        return value1 - value2;
+                    }
+                }, "reducer-store");
+
+        reduce.foreach(new ForeachAction<String, Long>() {
+            @Override
+            public void apply(final String key, final Long value) {
+                reduceResults.put(key, value);
+            }
+        });
+
+        one.leftJoin(reduce, new ValueJoiner<String, Long, String>() {
+            @Override
+            public String apply(final String value1, final Long value2) {
+                return value1 + ":" + value2;
+            }
+        })
+                .mapValues(new ValueMapper<String, String>() {
+                    @Override
+                    public String apply(final String value) {
+                        return value;
+                    }
+                });
+
+        final KStreamTestDriver driver = new KStreamTestDriver(builder, 
stateDir, 111);
+        driver.process(reduceTopic, "1", new Change<>(1L, null));
+        driver.process("tableOne", "2", "2");
+        // this should trigger eviction on the reducer-store topic
+        driver.process(reduceTopic, "2", new Change<>(2L, null));
+        // this wont as it is the same value
+        driver.process(reduceTopic, "2", new Change<>(2L, null));
+        assertEquals(Long.valueOf(2L), reduceResults.get("2"));
+
+        // this will trigger eviction on the tableOne topic
+        // that in turn will cause an eviction on reducer-topic. It will flush
+        // key 2 as it is the only dirty entry in the cache
+        driver.process("tableOne", "1", "5");
+        assertEquals(Long.valueOf(4L), reduceResults.get("2"));
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
index 54ee43c..a146316 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
@@ -280,6 +280,7 @@ public class ProcessorTopologyTest {
                 .addSink("sink-2", OUTPUT_TOPIC_2, 
constantPartitioner(partition), "processor-2");
     }
 
+
     /**
      * A processor that simply forwards all messages to all children.
      */

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index e84e9ba..aca974b 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -32,6 +32,7 @@ import 
org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
 import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.apache.kafka.test.MockProcessorContext;
@@ -269,6 +270,12 @@ public class KeyValueStoreTestDriver<K, V> {
             public Map<String, Object> appConfigsWithPrefix(String prefix) {
                 return new StreamsConfig(props).originalsWithPrefix(prefix);
             }
+
+            @Override
+            public ProcessorNode currentNode() {
+                return null;
+            }
+
             @Override
             public ThreadCache getCache() {
                 return cache;

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
index 3067256..5c0d511 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java
@@ -191,4 +191,10 @@ public class NamedCacheTest {
     public void shouldNotThrowNullPointerWhenCacheIsEmptyAndEvictionCalled() 
throws Exception {
         cache.evict();
     }
+
+    @Test(expected = IllegalStateException.class)
+    public void 
shouldThrowIllegalStateExceptionWhenTryingToOverwriteDirtyEntryWithCleanEntry() 
throws Exception {
+        cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, 
true, 0, 0, 0, ""));
+        cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, 
false, 0, 0, 0, ""));
+    }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java 
b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
index ac58f37..05abbc6 100644
--- a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
@@ -56,14 +56,26 @@ public class KStreamTestDriver {
         this(builder, stateDir, Serdes.ByteArray(), Serdes.ByteArray());
     }
 
+    public KStreamTestDriver(KStreamBuilder builder, File stateDir, final long 
cacheSize) {
+        this(builder, stateDir, Serdes.ByteArray(), Serdes.ByteArray(), 
cacheSize);
+    }
+
     public KStreamTestDriver(KStreamBuilder builder,
                              File stateDir,
                              Serde<?> keySerde,
                              Serde<?> valSerde) {
+        this(builder, stateDir, keySerde, valSerde, DEFAULT_CACHE_SIZE_BYTES);
+    }
+
+    public KStreamTestDriver(KStreamBuilder builder,
+                             File stateDir,
+                             Serde<?> keySerde,
+                             Serde<?> valSerde,
+                             long cacheSize) {
         builder.setApplicationId("TestDriver");
         this.topology = builder.build(null);
         this.stateDir = stateDir;
-        this.cache = new ThreadCache(DEFAULT_CACHE_SIZE_BYTES);
+        this.cache = new ThreadCache(cacheSize);
         this.context = new MockProcessorContext(this, stateDir, keySerde, 
valSerde, new MockRecordCollector(), cache);
         this.context.setRecordContext(new ProcessorRecordContext(0, 0, 0, 
"topic"));
 
@@ -73,13 +85,14 @@ public class KStreamTestDriver {
         }
 
         for (ProcessorNode node : topology.processors()) {
-            currNode = node;
+            context.setCurrentNode(node);
             try {
                 node.init(context);
             } finally {
-                currNode = null;
+                context.setCurrentNode(null);
             }
         }
+
     }
 
     public ProcessorContext context() {
@@ -225,6 +238,10 @@ public class KStreamTestDriver {
 
     }
 
+    public void setCurrentNode(final ProcessorNode currentNode) {
+        currNode = currentNode;
+    }
+
 
     private class MockRecordCollector extends RecordCollector {
         public MockRecordCollector() {

http://git-wip-us.apache.org/repos/asf/kafka/blob/c1fb615a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
----------------------------------------------------------------------
diff --git 
a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java 
b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
index 8ad2fa9..cafdd9e 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
@@ -51,6 +51,7 @@ public class MockProcessorContext implements 
InternalProcessorContext, RecordCol
 
     long timestamp = -1L;
     private RecordContext recordContext;
+    private ProcessorNode currentNode;
 
     public MockProcessorContext(StateSerdes<?, ?> serdes, RecordCollector 
collector) {
         this(null, null, serdes.keySerde(), serdes.valueSerde(), collector, 
null);
@@ -248,7 +249,13 @@ public class MockProcessorContext implements 
InternalProcessorContext, RecordCol
 
     @Override
     public void setCurrentNode(final ProcessorNode currentNode) {
+        this.currentNode  = currentNode;
+        driver.setCurrentNode(currentNode);
+    }
 
+    @Override
+    public ProcessorNode currentNode() {
+        return currentNode;
     }
 
 }

Reply via email to