mjsax commented on code in PR #13996: URL: https://github.com/apache/kafka/pull/13996#discussion_r1265999671
########## streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/Graph.java: ########## @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.stream.Collectors; + +public class Graph<V extends Comparable<V>> { + public class Edge { + final V destination; + final int capacity; + final int cost; + int residualFlow; + int flow; + Edge counterEdge; + boolean forwardEdge; + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) { + this(destination, capacity, cost, residualFlow, flow, true); + } + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow, + final boolean forwardEdge) { Review Comment: nit: formatting (if it does not fit on one line, we should move each parameter onto its own line to simplify reading) ``` public Edge( final V destination, 
final int capacity, final int cost, final int residualFlow, final int flow, final boolean forwardEdge ) { ``` ########## streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/GraphTest.java: ########## @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; + +public class GraphTest { + private Graph<Integer> graph; + + @Before + public void setUp() { + /* + * Node 0 and 2 are both connected to node 1 and 3. There's a flow of 1 unit from 0 to 1 and 2 to + * 3. 
The total cost in this case is 5. Min cost should be 2 by flowing 1 unit from 0 to 3 and 2 + * to 1 + */ + graph = new Graph<>(); + graph.addEdge(0, 1, 1, 3, 1); + graph.addEdge(0, 3, 1, 1, 0); + graph.addEdge(2, 1, 1, 1, 0); + graph.addEdge(2, 3, 1, 2, 1); + graph.addEdge(4, 0, 1, 0, 1); + graph.addEdge(4, 2, 1, 0, 1); + graph.addEdge(1, 5, 1, 0, 1); + graph.addEdge(3, 5, 1, 0, 1); + graph.setSourceNode(4); + graph.setSinkNode(5); + } + + @Test + public void testBasic() { + final Set<Integer> nodes = graph.nodes(); + assertEquals(6, nodes.size()); + assertThat(nodes, contains(0, 1, 2, 3, 4, 5)); + + Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0); + assertEquals(2, edges.size()); + assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1)); + assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3)); + + edges = graph.edges(2); + assertEquals(2, edges.size()); + assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1)); + assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3)); + + edges = graph.edges(1); + assertEquals(1, edges.size()); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = graph.edges(3); + assertEquals(1, edges.size()); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = graph.edges(4); + assertEquals(2, edges.size()); + assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0)); + assertEquals(getEdge(2, 1, 0, 0, 1), edges.get(2)); + + edges = graph.edges(5); + assertNull(edges); + + assertFalse(graph.isResidualGraph()); + } + + @Test + public void testResidualGraph() { + final Graph<Integer> residualGraph = graph.residualGraph(); + final Graph<Integer> residualGraph1 = residualGraph.residualGraph(); + assertSame(residualGraph1, residualGraph); + + final Set<Integer> nodes = residualGraph.nodes(); + assertEquals(6, nodes.size()); + assertThat(nodes, contains(0, 1, 2, 3, 4, 5)); + + Map<Integer, Graph<Integer>.Edge> edges = residualGraph.edges(0); + assertEquals(3, edges.size()); + assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1)); + 
assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3)); + assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4)); + + edges = residualGraph.edges(2); + assertEquals(3, edges.size()); + assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1)); + assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3)); + assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4)); + + edges = residualGraph.edges(1); + assertEquals(3, edges.size()); + assertEquals(getEdge(0, 1, -3, 1, 0, false), edges.get(0)); + assertEquals(getEdge(2, 1, -1, 0, 0, false), edges.get(2)); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = residualGraph.edges(3); + assertEquals(3, edges.size()); + assertEquals(getEdge(0, 1, -1, 0, 0, false), edges.get(0)); + assertEquals(getEdge(2, 1, -2, 1, 0, false), edges.get(2)); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + assertTrue(residualGraph.isResidualGraph()); + } + + @Test + public void testInvalidOperation() { + final Graph<Integer> graph1 = new Graph<>(); + Exception exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, -1, 0, 0)); + assertEquals("Edge capacity cannot be negative", exception.getMessage()); + + exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, 1, 0, 2)); + assertEquals("Edge flow 2 cannot exceed capacity 1", exception.getMessage()); + + graph1.addEdge(0, 1, 1, 1, 1); + exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(1, 0, 1, 0, 0)); + assertEquals("There is already an edge from 0 to 1. 
Can not add an edge from 1 to 0 since " + + "there will create a cycle between two nodes", exception.getMessage()); + + final Graph<Integer> residualGraph = graph1.residualGraph(); + exception = assertThrows(IllegalStateException.class, residualGraph::solveMinCostFlow); + assertEquals("Should not be residual graph to solve min cost flow", exception.getMessage()); + + } + + @Test + public void testInvalidSource() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 0); + graph1.addEdge(1, 2, 1, 1, 0); + graph1.setSourceNode(1); + graph1.setSinkNode(2); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Source node 1 shouldn't have input 0", exception.getMessage()); + } + + @Test + public void testInvalidSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 0); + graph1.addEdge(1, 2, 1, 1, 0); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Sink node 1 shouldn't have output", exception.getMessage()); + } + + @Test + public void testInvalidFlow() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(0, 2, 2, 1, 2); + graph1.addEdge(1, 3, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 0); // Missing flow from 2 to 3 + graph1.setSourceNode(0); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 2 is 2 which doesn't match output flow 0", exception.getMessage()); + } + + @Test + public void testMissingSource() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(0, 2, 2, 1, 2); + graph1.addEdge(1, 3, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, 
graph1::solveMinCostFlow); + assertEquals("Output flow for source null is null which doesn't match input flow 3 for sink 3", + exception.getMessage()); + } + + @Test + public void testDisconnectedGraph() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 3 is 2 which doesn't match output flow null", + exception.getMessage()); + } + + @Test + public void testDisconnectedGraphCrossSourceSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSourceNode(0); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 1 is 1 which doesn't match output flow null", + exception.getMessage()); + } + + @Test + public void testJustSourceSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + graph1.solveMinCostFlow(); + assertEquals(1, graph1.totalCost()); + } + + @Test + public void testMinCostFlow() { + // Original graph, flow from 0 to 1 and 2 to 3 + Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0); + Graph<Integer>.Edge edge = edges.get(1); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edge = edges.get(3); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + edges = graph.edges(2); + edge = edges.get(3); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edge = edges.get(1); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + assertEquals(5, graph.totalCost()); + + graph.solveMinCostFlow(); + + assertEquals(2, graph.totalCost()); + + edges = graph.edges(0); + 
assertEquals(2, edges.size()); + + // No flow from 0 to 1 + edge = edges.get(1); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + // Flow from 0 to 3 now + edge = edges.get(3); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edges = graph.edges(2); + assertEquals(2, edges.size()); + + // No flow from 2 to 3 + edge = edges.get(3); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + // Flow from 2 to 1 now + edge = edges.get(1); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + } Review Comment: Seems we do not check all conditions? Flow from 0->1 is shifted to 0->3, right? ########## streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/Graph.java: ########## @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.stream.Collectors; + +public class Graph<V extends Comparable<V>> { + public class Edge { + final V destination; + final int capacity; + final int cost; + int residualFlow; + int flow; + Edge counterEdge; + boolean forwardEdge; + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) { + this(destination, capacity, cost, residualFlow, flow, true); + } + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow, + final boolean forwardEdge) { + Objects.requireNonNull(destination); + if (capacity < 0) { + throw new IllegalArgumentException("Edge capacity cannot be negative"); + } + if (flow > capacity) { + throw new IllegalArgumentException(String.format("Edge flow %d cannot exceed capacity %d", + flow, capacity)); + } + + this.destination = destination; + this.capacity = capacity; + this.cost = cost; + this.residualFlow = residualFlow; + this.flow = flow; + this.forwardEdge = forwardEdge; + } + + @Override + public boolean equals(final Object other) { + if (this == other) { + return true; + } + if (other == null || other.getClass() != getClass()) { + return false; + } + + final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other; + + return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity + && cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow + && forwardEdge == otherEdge.forwardEdge; + } + + @Override + public int hashCode() { + return Objects.hash(destination, capacity, cost, residualFlow, flow, forwardEdge); + } + + @Override + public String toString() { + return 
"{destination= " + destination + ", capacity=" + capacity + ", cost=" + cost + + ", residualFlow=" + residualFlow + ", flow=" + flow + ", forwardEdge=" + forwardEdge; + } + } + + private final SortedMap<V, SortedMap<V, Edge>> adjList = new TreeMap<>(); + private final SortedSet<V> nodes = new TreeSet<>(); + private final boolean isResidualGraph; + private V sourceNode, sinkNode; + + public Graph() { + this(false); + } + + private Graph(final boolean isResidualGraph) { + this.isResidualGraph = isResidualGraph; + } + + public void addEdge(final V u, final V v, final int capacity, final int cost, final int flow) { + addEdge(u, new Edge(v, capacity, cost, capacity - flow, flow)); + } + + public Set<V> nodes() { + return nodes; + } + + public Map<V, Edge> edges(final V node) { + return adjList.get(node); + } + + public boolean isResidualGraph() { + return isResidualGraph; + } + + public void setSourceNode(final V node) { + sourceNode = node; + } + + public void setSinkNode(final V node) { + sinkNode = node; + } + + public int totalCost() { + int totalCost = 0; + for (final Map.Entry<V, SortedMap<V, Edge>> nodeEdges : adjList.entrySet()) { + final SortedMap<V, Edge> edges = nodeEdges.getValue(); + for (final Entry<V, Edge> nodeEdge : edges.entrySet()) { Review Comment: nit: we could just iterate over `edges.values()`? ########## streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/Graph.java: ########## @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.stream.Collectors; + +public class Graph<V extends Comparable<V>> { + public class Edge { + final V destination; + final int capacity; + final int cost; + int residualFlow; Review Comment: Do we actually need this field? If I read the code right, for a forward edge, it's always `capacity - flow` and for a backward edge it's always `0`. So it seems redundant (and potentially error-prone to store it explicitly)? -- Instead we could have a `residualFlow()` method that computes it on-the-fly (we could also simplify the update logic when modifying flow as we only need to update the `flow` itself)? Or do I read the update logic inside `cancelNegativeCycle` incorrectly and those properties are not an invariant? ########## streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/Graph.java: ########## @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.stream.Collectors; + +public class Graph<V extends Comparable<V>> { + public class Edge { + final V destination; + final int capacity; + final int cost; + int residualFlow; + int flow; + Edge counterEdge; + boolean forwardEdge; + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) { + this(destination, capacity, cost, residualFlow, flow, true); + } + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow, + final boolean forwardEdge) { + Objects.requireNonNull(destination); + if (capacity < 0) { + throw new IllegalArgumentException("Edge capacity cannot be negative"); + } + if (flow > capacity) { + throw new IllegalArgumentException(String.format("Edge flow %d cannot exceed capacity %d", + flow, capacity)); + } + + this.destination = destination; + this.capacity = capacity; + this.cost = cost; + this.residualFlow = residualFlow; + this.flow = flow; + this.forwardEdge = forwardEdge; + } + + @Override + public boolean equals(final Object other) { + if 
(this == other) { + return true; + } + if (other == null || other.getClass() != getClass()) { + return false; + } + + final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other; + + return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity + && cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow + && forwardEdge == otherEdge.forwardEdge; + } + + @Override + public int hashCode() { + return Objects.hash(destination, capacity, cost, residualFlow, flow, forwardEdge); + } + + @Override + public String toString() { + return "{destination= " + destination + ", capacity=" + capacity + ", cost=" + cost Review Comment: `return "Edge: {...}"` ? ########## streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/GraphTest.java: ########## @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; + +public class GraphTest { + private Graph<Integer> graph; + + @Before + public void setUp() { + /* + * Node 0 and 2 are both connected to node 1 and 3. There's a flow of 1 unit from 0 to 1 and 2 to + * 3. The total cost in this case is 5. Min cost should be 2 by flowing 1 unit from 0 to 3 and 2 + * to 1 + */ + graph = new Graph<>(); + graph.addEdge(0, 1, 1, 3, 1); + graph.addEdge(0, 3, 1, 1, 0); + graph.addEdge(2, 1, 1, 1, 0); + graph.addEdge(2, 3, 1, 2, 1); + graph.addEdge(4, 0, 1, 0, 1); + graph.addEdge(4, 2, 1, 0, 1); + graph.addEdge(1, 5, 1, 0, 1); + graph.addEdge(3, 5, 1, 0, 1); + graph.setSourceNode(4); + graph.setSinkNode(5); + } + + @Test + public void testBasic() { + final Set<Integer> nodes = graph.nodes(); + assertEquals(6, nodes.size()); + assertThat(nodes, contains(0, 1, 2, 3, 4, 5)); + + Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0); + assertEquals(2, edges.size()); + assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1)); + assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3)); + + edges = graph.edges(2); + assertEquals(2, edges.size()); + assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1)); + assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3)); + + edges = graph.edges(1); + assertEquals(1, edges.size()); + assertEquals(getEdge(5, 1, 0, 0, 
1), edges.get(5)); + + edges = graph.edges(3); + assertEquals(1, edges.size()); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = graph.edges(4); + assertEquals(2, edges.size()); + assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0)); + assertEquals(getEdge(2, 1, 0, 0, 1), edges.get(2)); + + edges = graph.edges(5); + assertNull(edges); + + assertFalse(graph.isResidualGraph()); + } + + @Test + public void testResidualGraph() { + final Graph<Integer> residualGraph = graph.residualGraph(); + final Graph<Integer> residualGraph1 = residualGraph.residualGraph(); + assertSame(residualGraph1, residualGraph); + + final Set<Integer> nodes = residualGraph.nodes(); + assertEquals(6, nodes.size()); + assertThat(nodes, contains(0, 1, 2, 3, 4, 5)); + + Map<Integer, Graph<Integer>.Edge> edges = residualGraph.edges(0); + assertEquals(3, edges.size()); + assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1)); + assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3)); + assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4)); + + edges = residualGraph.edges(2); + assertEquals(3, edges.size()); + assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1)); + assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3)); + assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4)); + + edges = residualGraph.edges(1); + assertEquals(3, edges.size()); + assertEquals(getEdge(0, 1, -3, 1, 0, false), edges.get(0)); + assertEquals(getEdge(2, 1, -1, 0, 0, false), edges.get(2)); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = residualGraph.edges(3); + assertEquals(3, edges.size()); + assertEquals(getEdge(0, 1, -1, 0, 0, false), edges.get(0)); + assertEquals(getEdge(2, 1, -2, 1, 0, false), edges.get(2)); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + assertTrue(residualGraph.isResidualGraph()); + } + + @Test + public void testInvalidOperation() { + final Graph<Integer> graph1 = new Graph<>(); + Exception exception = assertThrows(IllegalArgumentException.class, () -> 
graph1.addEdge(0, 1, -1, 0, 0)); + assertEquals("Edge capacity cannot be negative", exception.getMessage()); + + exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, 1, 0, 2)); + assertEquals("Edge flow 2 cannot exceed capacity 1", exception.getMessage()); + + graph1.addEdge(0, 1, 1, 1, 1); + exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(1, 0, 1, 0, 0)); + assertEquals("There is already an edge from 0 to 1. Can not add an edge from 1 to 0 since " + + "there will create a cycle between two nodes", exception.getMessage()); + + final Graph<Integer> residualGraph = graph1.residualGraph(); + exception = assertThrows(IllegalStateException.class, residualGraph::solveMinCostFlow); + assertEquals("Should not be residual graph to solve min cost flow", exception.getMessage()); + + } + + @Test + public void testInvalidSource() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 0); + graph1.addEdge(1, 2, 1, 1, 0); + graph1.setSourceNode(1); + graph1.setSinkNode(2); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Source node 1 shouldn't have input 0", exception.getMessage()); + } + + @Test + public void testInvalidSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 0); + graph1.addEdge(1, 2, 1, 1, 0); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Sink node 1 shouldn't have output", exception.getMessage()); + } + + @Test + public void testInvalidFlow() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(0, 2, 2, 1, 2); + graph1.addEdge(1, 3, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 0); // Missing flow from 2 to 3 + graph1.setSourceNode(0); + graph1.setSinkNode(3); + final Exception exception = 
assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 2 is 2 which doesn't match output flow 0", exception.getMessage()); + } + + @Test + public void testMissingSource() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(0, 2, 2, 1, 2); + graph1.addEdge(1, 3, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Output flow for source null is null which doesn't match input flow 3 for sink 3", + exception.getMessage()); + } + + @Test + public void testDisconnectedGraph() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 3 is 2 which doesn't match output flow null", + exception.getMessage()); + } + + @Test + public void testDisconnectedGraphCrossSourceSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSourceNode(0); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 1 is 1 which doesn't match output flow null", + exception.getMessage()); + } + + @Test + public void testJustSourceSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + graph1.solveMinCostFlow(); + assertEquals(1, graph1.totalCost()); + } + + @Test + public void testMinCostFlow() { + // Original graph, flow from 0 to 1 and 2 to 3 + Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0); + Graph<Integer>.Edge edge = edges.get(1); + assertEquals(1, 
edge.flow); + assertEquals(0, edge.residualFlow); + + edge = edges.get(3); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + edges = graph.edges(2); + edge = edges.get(3); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edge = edges.get(1); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + assertEquals(5, graph.totalCost()); + + graph.solveMinCostFlow(); + + assertEquals(2, graph.totalCost()); + + edges = graph.edges(0); + assertEquals(2, edges.size()); + + // No flow from 0 to 1 + edge = edges.get(1); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + // Flow from 0 to 3 now + edge = edges.get(3); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edges = graph.edges(2); + assertEquals(2, edges.size()); + + // No flow from 2 to 3 + edge = edges.get(3); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + // Flow from 2 to 1 now + edge = edges.get(1); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + } + + @Test + public void testMinCostDetectNodeNotInNegativeCycle() { + final Graph<Integer> graph1 = new Graph<>(); + + graph1.addEdge(5, 0, 1, 0, 1); + graph1.addEdge(5, 1, 1, 0, 1); + + graph1.addEdge(0, 2, 1, 1, 0); + graph1.addEdge(0, 3, 1, 1, 0); + graph1.addEdge(0, 4, 1, 10, 1); + + graph1.addEdge(1, 2, 1, 1, 0); + graph1.addEdge(1, 3, 1, 10, 1); + graph1.addEdge(1, 4, 1, 1, 0); + + graph1.addEdge(2, 6, 0, 0, 0); + graph1.addEdge(3, 6, 1, 0, 1); + graph1.addEdge(4, 6, 1, 0, 1); + + graph1.setSourceNode(5); + graph1.setSinkNode(6); Review Comment: As above ########## streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/Graph.java: ########## @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.stream.Collectors; + +public class Graph<V extends Comparable<V>> { + public class Edge { + final V destination; + final int capacity; + final int cost; + int residualFlow; + int flow; + Edge counterEdge; + boolean forwardEdge; + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) { + this(destination, capacity, cost, residualFlow, flow, true); + } + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow, + final boolean forwardEdge) { + Objects.requireNonNull(destination); + if (capacity < 0) { + throw new IllegalArgumentException("Edge capacity cannot be negative"); + } + if (flow > capacity) { + throw new IllegalArgumentException(String.format("Edge flow %d cannot exceed capacity %d", + flow, capacity)); + } + + this.destination = destination; + this.capacity = capacity; + this.cost = cost; + this.residualFlow = residualFlow; + this.flow = flow; + this.forwardEdge = forwardEdge; + } + + @Override + public boolean equals(final Object other) { + if 
(this == other) { + return true; + } + if (other == null || other.getClass() != getClass()) { + return false; + } + + final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other; + + return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity Review Comment: nit formatting: ``` return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity && cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow && forwardEdge == otherEdge.forwardEdge; ``` ########## streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/Graph.java: ########## @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.stream.Collectors; + +public class Graph<V extends Comparable<V>> { + public class Edge { + final V destination; + final int capacity; + final int cost; + int residualFlow; + int flow; + Edge counterEdge; + boolean forwardEdge; + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow) { + this(destination, capacity, cost, residualFlow, flow, true); + } + + public Edge(final V destination, final int capacity, final int cost, final int residualFlow, final int flow, + final boolean forwardEdge) { + Objects.requireNonNull(destination); + if (capacity < 0) { + throw new IllegalArgumentException("Edge capacity cannot be negative"); + } + if (flow > capacity) { + throw new IllegalArgumentException(String.format("Edge flow %d cannot exceed capacity %d", + flow, capacity)); + } + + this.destination = destination; + this.capacity = capacity; + this.cost = cost; + this.residualFlow = residualFlow; + this.flow = flow; + this.forwardEdge = forwardEdge; + } + + @Override + public boolean equals(final Object other) { + if (this == other) { + return true; + } + if (other == null || other.getClass() != getClass()) { + return false; + } + + final Graph<?>.Edge otherEdge = (Graph<?>.Edge) other; + + return destination.equals(otherEdge.destination) && capacity == otherEdge.capacity + && cost == otherEdge.cost && residualFlow == otherEdge.residualFlow && flow == otherEdge.flow + && forwardEdge == otherEdge.forwardEdge; + } + + @Override + public int hashCode() { + return Objects.hash(destination, capacity, cost, residualFlow, flow, forwardEdge); + } + + @Override + public String toString() { + return 
"{destination= " + destination + ", capacity=" + capacity + ", cost=" + cost + + ", residualFlow=" + residualFlow + ", flow=" + flow + ", forwardEdge=" + forwardEdge; Review Comment: missing `}`. Should we also switch to one line per parameter we print to simplify reading? ########## streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/GraphTest.java: ########## @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; + +public class GraphTest { + private Graph<Integer> graph; + + @Before + public void setUp() { + /* + * Node 0 and 2 are both connected to node 1 and 3. There's a flow of 1 unit from 0 to 1 and 2 to + * 3. The total cost in this case is 5. Min cost should be 2 by flowing 1 unit from 0 to 3 and 2 + * to 1 + */ + graph = new Graph<>(); + graph.addEdge(0, 1, 1, 3, 1); + graph.addEdge(0, 3, 1, 1, 0); + graph.addEdge(2, 1, 1, 1, 0); + graph.addEdge(2, 3, 1, 2, 1); + graph.addEdge(4, 0, 1, 0, 1); + graph.addEdge(4, 2, 1, 0, 1); + graph.addEdge(1, 5, 1, 0, 1); + graph.addEdge(3, 5, 1, 0, 1); + graph.setSourceNode(4); Review Comment: Would it be simpler to use `0` as source? (At least for me, it's easier to follow what's going on if the numbers are "ordered".) Or, to avoid a lot of rewriting, name the source -1 (and the sink 99) so both are clearly different, and we don't need to update too much code. ########## streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/GraphTest.java: ########## @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; + +public class GraphTest { + private Graph<Integer> graph; + + @Before + public void setUp() { + /* + * Node 0 and 2 are both connected to node 1 and 3. There's a flow of 1 unit from 0 to 1 and 2 to + * 3. The total cost in this case is 5. 
Min cost should be 2 by flowing 1 unit from 0 to 3 and 2 + * to 1 + */ + graph = new Graph<>(); + graph.addEdge(0, 1, 1, 3, 1); + graph.addEdge(0, 3, 1, 1, 0); + graph.addEdge(2, 1, 1, 1, 0); + graph.addEdge(2, 3, 1, 2, 1); + graph.addEdge(4, 0, 1, 0, 1); + graph.addEdge(4, 2, 1, 0, 1); + graph.addEdge(1, 5, 1, 0, 1); + graph.addEdge(3, 5, 1, 0, 1); + graph.setSourceNode(4); + graph.setSinkNode(5); + } + + @Test + public void testBasic() { + final Set<Integer> nodes = graph.nodes(); + assertEquals(6, nodes.size()); + assertThat(nodes, contains(0, 1, 2, 3, 4, 5)); + + Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0); + assertEquals(2, edges.size()); + assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1)); + assertEquals(getEdge(3, 1, 1, 1, 0), edges.get(3)); + + edges = graph.edges(2); + assertEquals(2, edges.size()); + assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1)); + assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3)); + + edges = graph.edges(1); + assertEquals(1, edges.size()); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = graph.edges(3); + assertEquals(1, edges.size()); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = graph.edges(4); + assertEquals(2, edges.size()); + assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0)); + assertEquals(getEdge(2, 1, 0, 0, 1), edges.get(2)); + + edges = graph.edges(5); + assertNull(edges); + + assertFalse(graph.isResidualGraph()); + } + + @Test + public void testResidualGraph() { + final Graph<Integer> residualGraph = graph.residualGraph(); + final Graph<Integer> residualGraph1 = residualGraph.residualGraph(); + assertSame(residualGraph1, residualGraph); + + final Set<Integer> nodes = residualGraph.nodes(); + assertEquals(6, nodes.size()); + assertThat(nodes, contains(0, 1, 2, 3, 4, 5)); + + Map<Integer, Graph<Integer>.Edge> edges = residualGraph.edges(0); + assertEquals(3, edges.size()); + assertEquals(getEdge(1, 1, 3, 0, 1), edges.get(1)); + assertEquals(getEdge(3, 1, 1, 1, 0), 
edges.get(3)); + assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4)); + + edges = residualGraph.edges(2); + assertEquals(3, edges.size()); + assertEquals(getEdge(1, 1, 1, 1, 0), edges.get(1)); + assertEquals(getEdge(3, 1, 2, 0, 1), edges.get(3)); + assertEquals(getEdge(4, 1, 0, 1, 0, false), edges.get(4)); + + edges = residualGraph.edges(1); + assertEquals(3, edges.size()); + assertEquals(getEdge(0, 1, -3, 1, 0, false), edges.get(0)); + assertEquals(getEdge(2, 1, -1, 0, 0, false), edges.get(2)); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + edges = residualGraph.edges(3); + assertEquals(3, edges.size()); + assertEquals(getEdge(0, 1, -1, 0, 0, false), edges.get(0)); + assertEquals(getEdge(2, 1, -2, 1, 0, false), edges.get(2)); + assertEquals(getEdge(5, 1, 0, 0, 1), edges.get(5)); + + assertTrue(residualGraph.isResidualGraph()); + } + + @Test + public void testInvalidOperation() { + final Graph<Integer> graph1 = new Graph<>(); + Exception exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, -1, 0, 0)); + assertEquals("Edge capacity cannot be negative", exception.getMessage()); + + exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(0, 1, 1, 0, 2)); + assertEquals("Edge flow 2 cannot exceed capacity 1", exception.getMessage()); + + graph1.addEdge(0, 1, 1, 1, 1); + exception = assertThrows(IllegalArgumentException.class, () -> graph1.addEdge(1, 0, 1, 0, 0)); + assertEquals("There is already an edge from 0 to 1. 
Can not add an edge from 1 to 0 since " + + "there will create a cycle between two nodes", exception.getMessage()); + + final Graph<Integer> residualGraph = graph1.residualGraph(); + exception = assertThrows(IllegalStateException.class, residualGraph::solveMinCostFlow); + assertEquals("Should not be residual graph to solve min cost flow", exception.getMessage()); + + } + + @Test + public void testInvalidSource() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 0); + graph1.addEdge(1, 2, 1, 1, 0); + graph1.setSourceNode(1); + graph1.setSinkNode(2); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Source node 1 shouldn't have input 0", exception.getMessage()); + } + + @Test + public void testInvalidSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 0); + graph1.addEdge(1, 2, 1, 1, 0); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Sink node 1 shouldn't have output", exception.getMessage()); + } + + @Test + public void testInvalidFlow() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(0, 2, 2, 1, 2); + graph1.addEdge(1, 3, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 0); // Missing flow from 2 to 3 + graph1.setSourceNode(0); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 2 is 2 which doesn't match output flow 0", exception.getMessage()); + } + + @Test + public void testMissingSource() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(0, 2, 2, 1, 2); + graph1.addEdge(1, 3, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, 
graph1::solveMinCostFlow); + assertEquals("Output flow for source null is null which doesn't match input flow 3 for sink 3", + exception.getMessage()); + } + + @Test + public void testDisconnectedGraph() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 3 is 2 which doesn't match output flow null", + exception.getMessage()); + } + + @Test + public void testDisconnectedGraphCrossSourceSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.addEdge(2, 3, 2, 1, 2); + graph1.setSourceNode(0); + graph1.setSinkNode(3); + final Exception exception = assertThrows(IllegalStateException.class, graph1::solveMinCostFlow); + assertEquals("Input flow for node 1 is 1 which doesn't match output flow null", + exception.getMessage()); + } + + @Test + public void testJustSourceSink() { + final Graph<Integer> graph1 = new Graph<>(); + graph1.addEdge(0, 1, 1, 1, 1); + graph1.setSourceNode(0); + graph1.setSinkNode(1); + graph1.solveMinCostFlow(); + assertEquals(1, graph1.totalCost()); + } + + @Test + public void testMinCostFlow() { + // Original graph, flow from 0 to 1 and 2 to 3 + Map<Integer, Graph<Integer>.Edge> edges = graph.edges(0); + Graph<Integer>.Edge edge = edges.get(1); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edge = edges.get(3); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + edges = graph.edges(2); + edge = edges.get(3); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edge = edges.get(1); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + assertEquals(5, graph.totalCost()); + + graph.solveMinCostFlow(); + + assertEquals(2, graph.totalCost()); + + edges = graph.edges(0); + 
assertEquals(2, edges.size()); + + // No flow from 0 to 1 + edge = edges.get(1); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + // Flow from 0 to 3 now + edge = edges.get(3); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + + edges = graph.edges(2); + assertEquals(2, edges.size()); + + // No flow from 2 to 3 + edge = edges.get(3); + assertEquals(0, edge.flow); + assertEquals(1, edge.residualFlow); + + // Flow from 2 to 1 now + edge = edges.get(1); + assertEquals(1, edge.flow); + assertEquals(0, edge.residualFlow); + } + + @Test + public void testMinCostDetectNodeNotInNegativeCycle() { + final Graph<Integer> graph1 = new Graph<>(); + + graph1.addEdge(5, 0, 1, 0, 1); + graph1.addEdge(5, 1, 1, 0, 1); + + graph1.addEdge(0, 2, 1, 1, 0); + graph1.addEdge(0, 3, 1, 1, 0); + graph1.addEdge(0, 4, 1, 10, 1); + + graph1.addEdge(1, 2, 1, 1, 0); + graph1.addEdge(1, 3, 1, 10, 1); + graph1.addEdge(1, 4, 1, 1, 0); + + graph1.addEdge(2, 6, 0, 0, 0); + graph1.addEdge(3, 6, 1, 0, 1); + graph1.addEdge(4, 6, 1, 0, 1); + + graph1.setSourceNode(5); + graph1.setSinkNode(6); + + assertEquals(20, graph1.totalCost()); + + // In this graph, the node we found for negative cycle is 2. However 2 isn't in the negative + // cycle itself. 
Negative cycle is 1 -> 4 -> 0 -> 3 -> 1 + graph1.solveMinCostFlow(); + assertEquals(2, graph1.totalCost()); + + Map<Integer, Graph<Integer>.Edge> edges = graph1.edges(5); + assertEquals(getEdge(0, 1, 0, 0, 1), edges.get(0)); + assertEquals(getEdge(1, 1, 0, 0, 1), edges.get(1)); + + edges = graph1.edges(0); + assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2)); + assertEquals(getEdge(3, 1, 1, 0, 1), edges.get(3)); + assertEquals(getEdge(4, 1, 10, 1, 0), edges.get(4)); + + edges = graph1.edges(1); + assertEquals(getEdge(2, 1, 1, 1, 0), edges.get(2)); + assertEquals(getEdge(3, 1, 10, 1, 0), edges.get(3)); + assertEquals(getEdge(4, 1, 1, 0, 1), edges.get(4)); + + edges = graph1.edges(2); + assertEquals(getEdge(6, 0, 0, 0, 0), edges.get(6)); + + edges = graph1.edges(3); + assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6)); + + edges = graph1.edges(4); + assertEquals(getEdge(6, 1, 0, 0, 1), edges.get(6)); + } + + @Test + public void testDeterministic() { + final List<TestEdge> edgeList = new ArrayList<>(); + edgeList.add(new TestEdge(0, 1, 1, 2, 1)); + edgeList.add(new TestEdge(0, 2, 1, 1, 0)); + edgeList.add(new TestEdge(0, 3, 1, 1, 0)); + edgeList.add(new TestEdge(0, 4, 1, 1, 0)); + edgeList.add(new TestEdge(1, 5, 1, 1, 1)); + edgeList.add(new TestEdge(2, 5, 1, 1, 0)); + edgeList.add(new TestEdge(3, 5, 1, 1, 0)); + edgeList.add(new TestEdge(4, 5, 1, 1, 0)); + + // Test no matter the order of adding edges, min cost flow flows from 0 to 2 and then from 2 to 5 + for (int i = 0; i < 10; i++) { Review Comment: Why do we need to test this 10 times? Given that the test runs for each PR and in nightly builds, it seems sufficient to just run it once? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org