Repository: flink
Updated Branches:
  refs/heads/master f37507d94 -> 050f9a416


[FLINK-5864] [cep] Fix duplicate output patterns problem

The problem was that the Dewey numbers in
NFA::computeNextStates() were not updated correctly
when branching, i.e. when a single state produced
two new valid states.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/050f9a41
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/050f9a41
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/050f9a41

Branch: refs/heads/master
Commit: 050f9a416173f74022a1fd10eae2bf502ae1c2ae
Parents: f37507d
Author: Dawid Wysakowicz <[email protected]>
Authored: Mon Feb 27 16:10:21 2017 +0100
Committer: kl0u <[email protected]>
Committed: Fri Mar 3 14:19:16 2017 +0100

----------------------------------------------------------------------
 .../main/java/org/apache/flink/cep/nfa/NFA.java | 44 ++++++++--
 .../org/apache/flink/cep/nfa/NFAITCase.java     | 85 ++++++++++++++++++++
 2 files changed, 122 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/050f9a41/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java 
b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java
index aefddb2..0ff496f 100644
--- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java
+++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/nfa/NFA.java
@@ -38,9 +38,11 @@ import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Queue;
@@ -220,6 +222,22 @@ public class NFA<T> implements Serializable {
        }
 
        /**
+        * Comparator used for imposing the assumption that IGNORE is always 
the last StateTransition in a state.
+        */
+       private interface StateTransitionComparator<T> extends 
Comparator<StateTransition<T>>, Serializable {}
+       private final Comparator<StateTransition<T>> stateTransitionComparator 
= new StateTransitionComparator<T>() {
+               private static final long serialVersionUID = 
-2775474935413622278L;
+
+               @Override
+               public int compare(final StateTransition<T> o1, final 
StateTransition<T> o2) {
+                       if (o1.getAction() == o2.getAction()) {
+                               return 0;
+                       }
+                       return o1.getAction() == StateTransitionAction.IGNORE ? 
1 : -1;
+               }
+       };
+
+       /**
         * Computes the next computation states based on the given computation 
state, the current event,
         * its timestamp and the internal state machine.
         *
@@ -238,9 +256,13 @@ public class NFA<T> implements Serializable {
 
                states.push(state);
 
+               boolean branched = false;
                while (!states.isEmpty()) {
                        State<T> currentState = states.pop();
-                       Collection<StateTransition<T>> stateTransitions = 
currentState.getStateTransitions();
+                       final List<StateTransition<T>> stateTransitions = new 
ArrayList<>(currentState.getStateTransitions());
+
+                       // impose the IGNORE will be processed last
+                       Collections.sort(stateTransitions, 
stateTransitionComparator);
 
                        // check all state transitions for each state
                        for (StateTransition<T> stateTransition: 
stateTransitions) {
@@ -254,7 +276,18 @@ public class NFA<T> implements Serializable {
                                                                
states.push(stateTransition.getTargetState());
                                                                break;
                                                        case IGNORE:
-                                                               
resultingComputationStates.add(computationState);
+                                                               final 
DeweyNumber version;
+                                                               if (branched) {
+                                                                       version 
= computationState.getVersion().increase();
+                                                               } else {
+                                                                       version 
= computationState.getVersion();
+                                                               }
+                                                               
resultingComputationStates.add(new ComputationState<T>(
+                                                                       
computationState.getState(),
+                                                                       
computationState.getEvent(),
+                                                                       
computationState.getTimestamp(),
+                                                                       version,
+                                                                       
computationState.getStartTimestamp()));
 
                                                                // we have a 
new computation state referring to the same the shared entry
                                                                // the lock of 
the current computation is released later on
@@ -280,11 +313,8 @@ public class NFA<T> implements Serializable {
                                                                        
previousTimestamp = computationState.getTimestamp();
                                                                        
oldVersion = computationState.getVersion();
 
-                                                                       if 
(newState.equals(computationState.getState())) {
-                                                                               
newComputationStateVersion = oldVersion.increase();
-                                                                       } else {
-                                                                               
newComputationStateVersion = oldVersion.addStage();
-                                                                       }
+                                                                       
branched = true;
+                                                                       
newComputationStateVersion = oldVersion.addStage();
                                                                }
 
                                                                if 
(previousState.isStart()) {

http://git-wip-us.apache.org/repos/asf/flink/blob/050f9a41/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java
 
b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java
index 4a33c1e..ccae848 100644
--- 
a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java
+++ 
b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/nfa/NFAITCase.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.cep.nfa;
 
+import com.google.common.collect.Sets;
 import org.apache.flink.api.common.functions.FilterFunction;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.cep.Event;
@@ -244,4 +245,88 @@ public class NFAITCase extends TestLogger {
                assertEquals(expectedTimeoutPatterns, resultingTimeoutPatterns);
        }
 
+       @Test
+       public void testBranchingPattern() {
+               List<StreamRecord<Event>> inputEvents = new ArrayList<>();
+
+               Event startEvent = new Event(40, "start", 1.0);
+               SubEvent middleEvent1 = new SubEvent(41, "foo1", 1.0, 10.0);
+               SubEvent middleEvent2 = new SubEvent(42, "foo2", 1.0, 10.0);
+               SubEvent middleEvent3 = new SubEvent(43, "foo3", 1.0, 10.0);
+               SubEvent nextOne1 = new SubEvent(44, "next-one", 1.0, 2.0);
+               SubEvent nextOne2 = new SubEvent(45, "next-one", 1.0, 2.0);
+               Event endEvent=  new Event(46, "end", 1.0);
+
+               inputEvents.add(new StreamRecord<Event>(startEvent, 1));
+               inputEvents.add(new StreamRecord<Event>(middleEvent1, 3));
+               inputEvents.add(new StreamRecord<Event>(middleEvent2, 4));
+               inputEvents.add(new StreamRecord<Event>(middleEvent3, 5));
+               inputEvents.add(new StreamRecord<Event>(nextOne1, 6));
+               inputEvents.add(new StreamRecord<Event>(nextOne2, 7));
+               inputEvents.add(new StreamRecord<Event>(endEvent, 8));
+
+               Pattern<Event, ?> pattern = 
Pattern.<Event>begin("start").where(new FilterFunction<Event>() {
+                       private static final long serialVersionUID = 
5726188262756267490L;
+
+                       @Override
+                       public boolean filter(Event value) throws Exception {
+                               return value.getName().equals("start");
+                       }
+               })
+                       
.followedBy("middle-first").subtype(SubEvent.class).where(new 
FilterFunction<SubEvent>() {
+                               private static final long serialVersionUID = 
6215754202506583964L;
+
+                               @Override
+                               public boolean filter(SubEvent value) throws 
Exception {
+                                       return value.getVolume() > 5.0;
+                               }
+                       })
+                       
.followedBy("middle-second").subtype(SubEvent.class).where(new 
FilterFunction<SubEvent>() {
+                               private static final long serialVersionUID = 
6215754202506583964L;
+
+                               @Override
+                               public boolean filter(SubEvent value) throws 
Exception {
+                                       return 
value.getName().equals("next-one");
+                               }
+                       })
+                       .followedBy("end").where(new FilterFunction<Event>() {
+                               private static final long serialVersionUID = 
7056763917392056548L;
+
+                               @Override
+                               public boolean filter(Event value) throws 
Exception {
+                                       return value.getName().equals("end");
+                               }
+                       });
+
+               NFA<Event> nfa = NFACompiler.compile(pattern, 
Event.createTypeSerializer(), false);
+
+               List<Map<String, Event>> resultingPatterns = new ArrayList<>();
+
+               for (StreamRecord<Event> inputEvent: inputEvents) {
+                       Collection<Map<String, Event>> patterns = nfa.process(
+                               inputEvent.getValue(),
+                               inputEvent.getTimestamp()).f0;
+
+                       resultingPatterns.addAll(patterns);
+               }
+
+               assertEquals(6, resultingPatterns.size());
+
+               final Set<Set<Event>> patterns = new HashSet<>();
+               for (Map<String, Event> resultingPattern : resultingPatterns) {
+                       patterns.add(new HashSet<>(resultingPattern.values()));
+               }
+
+               assertEquals(Sets.newHashSet(
+                       Sets.newHashSet(startEvent, middleEvent1, nextOne1, 
endEvent),
+                       Sets.newHashSet(startEvent, middleEvent2, nextOne1, 
endEvent),
+                       Sets.newHashSet(startEvent, middleEvent3, nextOne1, 
endEvent),
+                       Sets.newHashSet(startEvent, middleEvent1, nextOne2, 
endEvent),
+                       Sets.newHashSet(startEvent, middleEvent2, nextOne2, 
endEvent),
+                       Sets.newHashSet(startEvent, middleEvent3, nextOne2, 
endEvent)
+               ), patterns);
+       }
+
+
+
 }

Reply via email to