[FLINK-8210][network-tests] Collect results into proper mock in MockEnvironment
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/af6bdb60 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/af6bdb60 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/af6bdb60 Branch: refs/heads/master Commit: af6bdb606e825d0d66ba532bcb9d8335f9f4c54b Parents: d5d4da1 Author: Piotr Nowojski <piotr.nowoj...@gmail.com> Authored: Tue Dec 5 09:36:55 2017 +0100 Committer: Stefan Richter <s.rich...@data-artisans.com> Committed: Mon Jan 8 11:46:00 2018 +0100 ---------------------------------------------------------------------- .../RecordCollectingResultPartitionWriter.java | 88 ++++++++++++++++++++ .../operators/testutils/MockEnvironment.java | 65 +-------------- 2 files changed, 92 insertions(+), 61 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/af6bdb60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordCollectingResultPartitionWriter.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordCollectingResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordCollectingResultPartitionWriter.java new file mode 100644 index 0000000..24ccae1 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordCollectingResultPartitionWriter.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.api.writer; + +import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.types.Record; + +import java.io.IOException; +import java.util.List; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * {@link ResultPartitionWriter} that collects output on the List. + */ +public class RecordCollectingResultPartitionWriter implements ResultPartitionWriter { + private final List<Record> output; + private final BufferProvider bufferProvider; + + private final Record record = new Record(); + private final RecordDeserializer<Record> deserializer = new AdaptiveSpanningRecordDeserializer<>(); + + public RecordCollectingResultPartitionWriter(List<Record> output, BufferProvider bufferProvider) { + this.output = checkNotNull(output); + this.bufferProvider = checkNotNull(bufferProvider); + } + + @Override + public BufferProvider getBufferProvider() { + return bufferProvider; + } + + @Override + public ResultPartitionID getPartitionId() { + return new ResultPartitionID(); + } + + @Override + public int getNumberOfSubpartitions() { + return 1; + } + + @Override + public int getNumTargetKeyGroups() { + return 1; + } + + @Override + public void writeBuffer(Buffer buffer, int targetChannel) throws IOException { + checkState(targetChannel < getNumberOfSubpartitions()); + + deserializer.setNextBuffer(buffer); + + while (deserializer.hasUnfinishedData()) { + RecordDeserializer.DeserializationResult result = deserializer.getNextRecord(record); + + if (result.isFullRecord()) { + output.add(record.createCopy()); + } + + if (result == RecordDeserializer.DeserializationResult.LAST_RECORD_FROM_BUFFER + || result == RecordDeserializer.DeserializationResult.PARTIAL_RECORD) { + break; + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/af6bdb60/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 2fdddb5..bc5677e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -23,7 +23,6 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.Path; -import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; @@ -33,14 +32,11 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; -import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.writer.RecordCollectingResultPartitionWriter; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.BufferProvider; -import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.IteratorWrappingTestSingleInputGate; +import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; @@ -54,9 +50,6 @@ import org.apache.flink.types.Record; import org.apache.flink.util.MutableObjectIterator; import org.apache.flink.util.Preconditions; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - import java.util.Collections; import java.util.LinkedList; import java.util.List; @@ -65,11 +58,6 @@ import java.util.concurrent.Future; import static org.apache.flink.util.Preconditions.checkState; import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * IMPORTANT! Remember to close environment after usage! @@ -106,7 +94,7 @@ public class MockEnvironment implements Environment, AutoCloseable { private final ClassLoader userCodeClassLoader; - private TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + private final TaskEventDispatcher taskEventDispatcher = new TaskEventDispatcher(); public MockEnvironment(String taskName, long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) { this(taskName, memorySize, inputSplitProvider, bufferSize, new Configuration(), new ExecutionConfig()); @@ -196,52 +184,7 @@ public class MockEnvironment implements Environment, AutoCloseable { public void addOutput(final List<Record> outputList) { try { - // The record-oriented writers wrap the buffer writer. We mock it - // to collect the returned buffers and deserialize the content to - // the output list - BufferProvider mockBufferProvider = mock(BufferProvider.class); - when(mockBufferProvider.requestBufferBlocking()).thenAnswer(new Answer<Buffer>() { - - @Override - public Buffer answer(InvocationOnMock invocationOnMock) throws Throwable { - return new Buffer(MemorySegmentFactory.allocateUnpooledSegment(bufferSize), mock(BufferRecycler.class)); - } - }); - - ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class); - when(mockWriter.getNumberOfSubpartitions()).thenReturn(1); - when(mockWriter.getBufferProvider()).thenReturn(mockBufferProvider); - - final Record record = new Record(); - final RecordDeserializer<Record> deserializer = new AdaptiveSpanningRecordDeserializer<Record>(); - - // Add records from the buffer to the output list - doAnswer(new Answer<Void>() { - - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - Buffer buffer = (Buffer) invocationOnMock.getArguments()[0]; - - deserializer.setNextBuffer(buffer); - - while (deserializer.hasUnfinishedData()) { - RecordDeserializer.DeserializationResult result = deserializer.getNextRecord(record); - - if (result.isFullRecord()) { - outputList.add(record.createCopy()); - } - - if (result == RecordDeserializer.DeserializationResult.LAST_RECORD_FROM_BUFFER - || result == RecordDeserializer.DeserializationResult.PARTIAL_RECORD) { - break; - } - } - - return null; - } - }).when(mockWriter).writeBuffer(any(Buffer.class), anyInt()); - - outputs.add(mockWriter); + outputs.add(new RecordCollectingResultPartitionWriter(outputList, new TestPooledBufferProvider(Integer.MAX_VALUE))); } catch (Throwable t) { t.printStackTrace();