diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index f08d8b0f984cf..43c3d23b6304d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -90,7 +90,6 @@ protected void channelRead0( ManagedBuffer buf; try { streamManager.checkAuthorization(client, msg.streamChunkId.streamId); - streamManager.registerChannel(channel, msg.streamChunkId.streamId); buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format("Error opening block %s for request from %s", diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 0f6a8824d95e5..6fafcc131fa24 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.netty.channel.Channel; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -49,7 +50,7 @@ final Iterator<ManagedBuffer> buffers; // The channel associated to the stream - Channel associatedChannel = null; + final Channel associatedChannel; // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. @@ -58,9 +59,10 @@ // Used to keep track of the number of chunks being transferred and not finished yet. volatile long chunksBeingTransferred = 0L; - StreamState(String appId, Iterator<ManagedBuffer> buffers) { + StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) { this.appId = appId; this.buffers = Preconditions.checkNotNull(buffers); + this.associatedChannel = channel; } } @@ -71,13 +73,6 @@ public OneForOneStreamManager() { streams = new ConcurrentHashMap<>(); } - @Override - public void registerChannel(Channel channel, long streamId) { - if (streams.containsKey(streamId)) { - streams.get(streamId).associatedChannel = channel; - } - } - @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); @@ -195,11 +190,19 @@ public long chunksBeingTransferred() { * * If an app ID is provided, only callers who've authenticated with the given app ID will be * allowed to fetch from this stream. + * + * This method also associates the stream with a single client connection, which is guaranteed + * to be the only reader of the stream. Once the connection is closed, the stream will never + * be used again, enabling cleanup by `connectionTerminated`. */ - public long registerStream(String appId, Iterator<ManagedBuffer> buffers) { + public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) { long myStreamId = nextStreamId.getAndIncrement(); - streams.put(myStreamId, new StreamState(appId, buffers)); + streams.put(myStreamId, new StreamState(appId, buffers, channel)); return myStreamId; } + @VisibleForTesting + public int numStreamStates() { + return streams.size(); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java index c535295831606..e48d27be1126a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -60,16 +60,6 @@ public ManagedBuffer openStream(String streamId) { throw new UnsupportedOperationException(); } - /** - * Associates a stream with a single client connection, which is guaranteed to be the only reader - * of the stream. The getChunk() method will be called serially on this connection and once the - * connection is closed, the stream will never be used again, enabling cleanup. - * - * This must be called before the first getChunk() on the stream, but it may be invoked multiple - * times with the same channel and stream id. - */ - public void registerChannel(Channel channel, long streamId) { } - /** * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not * to read from the associated streams again, so any state can be cleaned up. diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java index 2c72c53a33ae8..6c9239606bb8c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java @@ -64,8 +64,7 @@ public void handleChunkFetchRequest() throws Exception { managedBuffers.add(new TestManagedBuffer(20)); managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); - long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); - streamManager.registerChannel(channel, streamId); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel); TransportClient reverseClient = mock(TransportClient.class); ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient, rpcHandler.getStreamManager(), 2L); diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index ad640415a8e6d..a87f6c11a2bfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -58,8 +58,10 @@ public void handleStreamRequest() throws Exception { managedBuffers.add(new TestManagedBuffer(20)); managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); - long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); - streamManager.registerChannel(channel, streamId); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel); + + assert streamManager.numStreamStates() == 1; + TransportClient reverseClient = mock(TransportClient.class); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, rpcHandler, 2L); @@ -94,5 +96,8 @@ public void handleStreamRequest() throws Exception { requestHandler.handle(request3); verify(channel, times(1)).close(); assert responseAndPromisePairs.size() == 3; + + streamManager.connectionTerminated(channel); + assert streamManager.numStreamStates() == 0; } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java index c647525d8f1bd..4248762c32389 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); buffers.add(buffer1); buffers.add(buffer2); - long streamId = manager.registerStream("appId", buffers.iterator()); Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); - manager.registerChannel(dummyChannel, streamId); + manager.registerStream("appId", buffers.iterator(), dummyChannel); + assert manager.numStreamStates() == 1; manager.connectionTerminated(dummyChannel); Mockito.verify(buffer1, Mockito.times(1)).release(); Mockito.verify(buffer2, Mockito.times(1)).release(); + assert manager.numStreamStates() == 0; } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 788a845c57755..b25e48a164e6b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -92,7 +92,7 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); long streamId = streamManager.registerStream(client.getClientId(), - new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds)); + new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel()); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 4cc9a16e1449f..537c277cd26b5 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -103,7 +103,8 @@ public void testOpenShuffleBlocks() { @SuppressWarnings("unchecked") ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>) (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class); - verify(streamManager, times(1)).registerStream(anyString(), stream.capture()); + verify(streamManager, times(1)).registerStream(anyString(), stream.capture(), + any()); Iterator<ManagedBuffer> buffers = stream.getValue(); assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 7076701421e2e..27f4f94ea55f8 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -59,7 +59,8 @@ class NettyBlockRpcServer( val blocksNum = openBlocks.blockIds.length val blocks = for (i <- (0 until blocksNum).view) yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) - val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava, + client.getChannel) logTrace(s"Registered streamId $streamId with $blocksNum buffers") responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org