This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 9a15fb43b9e3aaa9322fb9e5f37b181f1813fda7
Author: Yun Gao <gaoyunhen...@gmail.com>
AuthorDate: Wed Oct 6 18:36:29 2021 +0800

    [hotfix][iteration] Simplify the head operator test
---
 .../flink/iteration/operator/HeadOperatorTest.java | 281 ++++++++++++---------
 1 file changed, 159 insertions(+), 122 deletions(-)

diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index f54422e..3ded30a 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -25,9 +25,9 @@ import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import 
org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.jobgraph.OperatorID;
-import 
org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
@@ -38,11 +38,16 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import 
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
+import org.apache.flink.util.function.FunctionWithException;
 
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -62,144 +67,173 @@ public class HeadOperatorTest extends TestLogger {
     @Test
     public void testForwardRecords() throws Exception {
         IterationID iterationId = new IterationID();
-        try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
-                new StreamTaskMailboxTestHarnessBuilder<>(
-                                OneInputStreamTask::new,
-                                new 
IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new 
IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .setupOutputForSingletonOperatorChain(
-                                new RecordingHeadOperatorFactory(
-                                        iterationId, 0, false, 5, 
MockOperatorEventGateway::new))
-                        .build()) {
-            harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
-            putFeedbackRecords(
-                    iterationId, 0, new 
StreamRecord<>(IterationRecord.newRecord(3, 1), 3));
-            harness.processAll();
-            harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(2, 0), 3));
-            putFeedbackRecords(
-                    iterationId, 0, new 
StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
-            harness.processAll();
-
-            List<StreamRecord<IterationRecord<Integer>>> expectedOutput =
-                    Arrays.asList(
-                            new StreamRecord<>(IterationRecord.newRecord(1, 
0), 2),
-                            new StreamRecord<>(IterationRecord.newRecord(3, 
1), 3),
-                            new StreamRecord<>(IterationRecord.newRecord(2, 
0), 3),
-                            new StreamRecord<>(IterationRecord.newRecord(4, 
1), 4));
-            assertEquals(expectedOutput, new ArrayList<>(harness.getOutput()));
-
-            RegularHeadOperatorRecordProcessor recordProcessor =
-                    (RegularHeadOperatorRecordProcessor)
-                            
RecordingHeadOperatorFactory.latestHeadOperator.getRecordProcessor();
-
-            assertEquals(2, (long) 
recordProcessor.getNumFeedbackRecordsPerEpoch().get(1));
-        }
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
+                    putFeedbackRecords(iterationId, 
IterationRecord.newRecord(3, 1), 3L);
+                    harness.processAll();
+                    harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(2, 0), 3));
+                    putFeedbackRecords(iterationId, 
IterationRecord.newRecord(4, 1), 4L);
+                    harness.processAll();
+
+                    List<StreamRecord<IterationRecord<Integer>>> 
expectedOutput =
+                            Arrays.asList(
+                                    new 
StreamRecord<>(IterationRecord.newRecord(1, 0), 2),
+                                    new 
StreamRecord<>(IterationRecord.newRecord(3, 1), 3),
+                                    new 
StreamRecord<>(IterationRecord.newRecord(2, 0), 3),
+                                    new 
StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
+                    assertEquals(expectedOutput, new 
ArrayList<>(harness.getOutput()));
+
+                    RegularHeadOperatorRecordProcessor recordProcessor =
+                            (RegularHeadOperatorRecordProcessor)
+                                    
RecordingHeadOperatorFactory.latestHeadOperator
+                                            .getRecordProcessor();
+
+                    assertEquals(2, (long) 
recordProcessor.getNumFeedbackRecordsPerEpoch().get(1));
+
+                    return null;
+                });
     }
 
     @Test(timeout = 60000)
     public void testSynchronizingEpochWatermark() throws Exception {
         IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
+
+                    // We will start a new thread to simulate the operator 
coordinator thread
+                    CompletableFuture<Void> taskExecuteResult =
+                            CompletableFuture.supplyAsync(
+                                    () -> {
+                                        try {
+                                            RecordingOperatorEventGateway 
eventGateway =
+                                                    
(RecordingOperatorEventGateway)
+                                                            
RecordingHeadOperatorFactory
+                                                                    
.latestHeadOperator
+                                                                    
.getOperatorEventGateway();
+
+                                            // We should get the aligned event 
for round 0 on
+                                            // endInput
+                                            assertNextOperatorEvent(
+                                                    new SubtaskAlignedEvent(0, 
0, false),
+                                                    eventGateway);
+                                            dispatchOperatorEvent(
+                                                    harness,
+                                                    operatorId,
+                                                    new 
GloballyAlignedEvent(0, false));
+
+                                            putFeedbackRecords(
+                                                    iterationId,
+                                                    
IterationRecord.newRecord(4, 1),
+                                                    4L);
+                                            putFeedbackRecords(
+                                                    iterationId,
+                                                    
IterationRecord.newEpochWatermark(1, "tail"),
+                                                    0L);
+
+                                            assertNextOperatorEvent(
+                                                    new SubtaskAlignedEvent(1, 
1, false),
+                                                    eventGateway);
+                                            dispatchOperatorEvent(
+                                                    harness,
+                                                    operatorId,
+                                                    new 
GloballyAlignedEvent(1, true));
+
+                                            while 
(RecordingHeadOperatorFactory.latestHeadOperator
+                                                            .getStatus()
+                                                    == 
HeadOperator.HeadOperatorStatus.RUNNING) ;
+                                            putFeedbackRecords(
+                                                    iterationId,
+                                                    
IterationRecord.newEpochWatermark(
+                                                            Integer.MAX_VALUE 
+ 1, "tail"),
+                                                    null);
+
+                                            return null;
+                                        } catch (Throwable e) {
+                                            
RecordingHeadOperatorFactory.latestHeadOperator
+                                                    .getMailboxExecutor()
+                                                    .execute(
+                                                            () -> {
+                                                                throw e;
+                                                            },
+                                                            "poison mail");
+                                            throw new CompletionException(e);
+                                        }
+                                    });
+
+                    // Mark the input as finished.
+                    harness.processEvent(EndOfData.INSTANCE);
+
+                    // There should be no exception
+                    taskExecuteResult.get();
+
+                    assertEquals(
+                            Arrays.asList(
+                                    new 
StreamRecord<>(IterationRecord.newRecord(1, 0), 2),
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    0,
+                                                    
OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0),
+                                    new 
StreamRecord<>(IterationRecord.newRecord(4, 1), 4),
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    Integer.MAX_VALUE,
+                                                    
OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0)),
+                            new ArrayList<>(harness.getOutput()));
+                    return null;
+                });
+    }
+
+    private <T> T createHarnessAndRun(
+            IterationID iterationId,
+            OperatorID operatorId,
+            @Nullable TaskStateSnapshot snapshot,
+            FunctionWithException<
+                            
StreamTaskMailboxTestHarness<IterationRecord<Integer>>, T, Exception>
+                    runnable)
+            throws Exception {
         try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 OneInputStreamTask::new,
                                 new 
IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
                         .addInput(new 
IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setTaskStateSnapshot(
+                                1, snapshot == null ? new TaskStateSnapshot() 
: snapshot)
                         .setupOutputForSingletonOperatorChain(
                                 new RecordingHeadOperatorFactory(
                                         iterationId,
                                         0,
                                         false,
                                         5,
-                                        RecordingOperatorEventGateway::new))
+                                        RecordingOperatorEventGateway::new),
+                                operatorId)
                         .build()) {
-
-            OperatorID operatorId = 
RecordingHeadOperatorFactory.latestHeadOperator.getOperatorID();
-            harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
-
-            // We will start a new thread to simulate the operator coordinator 
thread
-            CompletableFuture<Void> taskExecuteResult =
-                    CompletableFuture.supplyAsync(
-                            () -> {
-                                try {
-                                    RecordingOperatorEventGateway eventGateway 
=
-                                            (RecordingOperatorEventGateway)
-                                                    
RecordingHeadOperatorFactory.latestHeadOperator
-                                                            
.getOperatorEventGateway();
-
-                                    // We should get the aligned event for 
round 0 on endInput
-                                    assertNextOperatorEvent(
-                                            new SubtaskAlignedEvent(0, 0, 
false), eventGateway);
-                                    harness.getStreamTask()
-                                            .dispatchOperatorEvent(
-                                                    operatorId,
-                                                    new SerializedValue<>(
-                                                            new 
GloballyAlignedEvent(0, false)));
-
-                                    putFeedbackRecords(
-                                            iterationId,
-                                            0,
-                                            new 
StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
-                                    putFeedbackRecords(
-                                            iterationId,
-                                            0,
-                                            new StreamRecord<>(
-                                                    
IterationRecord.newEpochWatermark(1, "tail"),
-                                                    0));
-
-                                    assertNextOperatorEvent(
-                                            new SubtaskAlignedEvent(1, 1, 
false), eventGateway);
-                                    harness.getStreamTask()
-                                            .dispatchOperatorEvent(
-                                                    operatorId,
-                                                    new SerializedValue<>(
-                                                            new 
GloballyAlignedEvent(1, true)));
-
-                                    while 
(RecordingHeadOperatorFactory.latestHeadOperator
-                                                    .getStatus()
-                                            == 
HeadOperator.HeadOperatorStatus.RUNNING) {}
-                                    putFeedbackRecords(
-                                            iterationId,
-                                            0,
-                                            new StreamRecord<>(
-                                                    
IterationRecord.newEpochWatermark(
-                                                            Integer.MAX_VALUE 
+ 1, "tail")));
-
-                                    return null;
-                                } catch (Throwable e) {
-                                    
RecordingHeadOperatorFactory.latestHeadOperator
-                                            .getMailboxExecutor()
-                                            .execute(
-                                                    () -> {
-                                                        throw e;
-                                                    },
-                                                    "poison mail");
-                                    throw new CompletionException(e);
-                                }
-                            });
-
-            // Mark the input as finished.
-            harness.processEvent(EndOfData.INSTANCE);
-
-            // There should be no exception
-            taskExecuteResult.get();
-
-            assertEquals(
-                    Arrays.asList(
-                            new StreamRecord<>(IterationRecord.newRecord(1, 
0), 2),
-                            new StreamRecord<>(
-                                    IterationRecord.newEpochWatermark(
-                                            0, 
OperatorUtils.getUniqueSenderId(operatorId, 0)),
-                                    0),
-                            new StreamRecord<>(IterationRecord.newRecord(4, 
1), 4),
-                            new StreamRecord<>(
-                                    IterationRecord.newEpochWatermark(
-                                            Integer.MAX_VALUE,
-                                            
OperatorUtils.getUniqueSenderId(operatorId, 0)),
-                                    0)),
-                    new ArrayList<>(harness.getOutput()));
+            return runnable.apply(harness);
         }
     }
 
+    private static void dispatchOperatorEvent(
+            StreamTaskMailboxTestHarness<?> harness,
+            OperatorID operatorId,
+            OperatorEvent operatorEvent)
+            throws IOException, FlinkException {
+        harness.getStreamTask()
+                .dispatchOperatorEvent(operatorId, new 
SerializedValue<>(operatorEvent));
+    }
+
     private static void assertNextOperatorEvent(
             OperatorEvent expectedEvent, RecordingOperatorEventGateway 
eventGateway)
             throws InterruptedException {
@@ -209,14 +243,17 @@ public class HeadOperatorTest extends TestLogger {
     }
 
     private static void putFeedbackRecords(
-            IterationID iterationId, int feedbackIndex, 
StreamRecord<IterationRecord<?>> record) {
+            IterationID iterationId, IterationRecord<?> record, @Nullable Long 
timestamp) {
         FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
                 FeedbackChannelBroker.get()
                         .getChannel(
                                 
OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
-                                                iterationId, feedbackIndex)
+                                                iterationId, 0)
                                         .withSubTaskIndex(0, 0));
-        feedbackChannel.put(record);
+        feedbackChannel.put(
+                timestamp == null
+                        ? new StreamRecord<>(record)
+                        : new StreamRecord<>(record, timestamp));
     }
 
     private static class RecordingOperatorEventGateway implements 
OperatorEventGateway {

Reply via email to