lucasbru commented on code in PR #18397:
URL: https://github.com/apache/kafka/pull/18397#discussion_r1908739762


##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/streams/topics/CopartitionedTopicsEnforcerTest.java:
##########
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.group.streams.topics;
+
+import org.apache.kafka.common.requests.StreamsGroupHeartbeatResponse.Status;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class CopartitionedTopicsEnforcerTest {
+
+    private static final LogContext LOG_CONTEXT = new LogContext();
+
+    private static Function<String, OptionalInt> 
topicPartitionProvider(Map<String, Integer> topicPartitionCounts) {
+        return topic -> {
+            Integer a = topicPartitionCounts.get(topic);
+            return a == null ? OptionalInt.empty() : OptionalInt.of(a);
+        };
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfNoPartitionsFoundForCoPartitionedTopic()
 {
+        final String topic = "topic";
+        final Map<String, Integer> topicPartitionCounts = 
Collections.emptyMap();
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = 
assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(topic),
+                Set.of(),
+                Set.of()
+            ));
+        assertEquals(Status.MISSING_SOURCE_TOPICS, ex.status());
+        assertEquals("Following topics are missing: [topic]", ex.getMessage());
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfPartitionCountsForCoPartitionedTopicsDontMatch()
 {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final Map<String, Integer> topicPartitionCounts = 
Map.of(firstSourceTopic, 2, secondSourceTopic, 1);
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = 
assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic),
+                Set.of(),
+                Set.of()
+            )
+        );
+        assertEquals(Status.INCORRECTLY_PARTITIONED_TOPICS, ex.status());
+        assertEquals("Following topics do not have the same number of 
partitions: " +
+            "[{first=2, second=1}]", ex.getMessage());
+    }
+
+
+    @Test
+    public void shouldEnforceCopartitioningOnRepartitionTopics() {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final String repartitionTopic = "repartitioned";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            firstSourceTopic, 2,
+            secondSourceTopic, 2,
+            repartitionTopic, 10
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> result =
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic, repartitionTopic),
+                Set.of(),
+                Set.of(repartitionTopic)
+            );
+
+        assertEquals(Map.of(repartitionTopic, 2), result);
+    }
+
+

Review Comment:
   Done



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/streams/topics/CopartitionedTopicsEnforcerTest.java:
##########
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.group.streams.topics;
+
+import org.apache.kafka.common.requests.StreamsGroupHeartbeatResponse.Status;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class CopartitionedTopicsEnforcerTest {
+
+    private static final LogContext LOG_CONTEXT = new LogContext();
+
+    private static Function<String, OptionalInt> 
topicPartitionProvider(Map<String, Integer> topicPartitionCounts) {
+        return topic -> {
+            Integer a = topicPartitionCounts.get(topic);
+            return a == null ? OptionalInt.empty() : OptionalInt.of(a);
+        };
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfNoPartitionsFoundForCoPartitionedTopic()
 {
+        final String topic = "topic";
+        final Map<String, Integer> topicPartitionCounts = 
Collections.emptyMap();
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = 
assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(topic),
+                Set.of(),
+                Set.of()
+            ));
+        assertEquals(Status.MISSING_SOURCE_TOPICS, ex.status());
+        assertEquals("Following topics are missing: [topic]", ex.getMessage());
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfPartitionCountsForCoPartitionedTopicsDontMatch()
 {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final Map<String, Integer> topicPartitionCounts = 
Map.of(firstSourceTopic, 2, secondSourceTopic, 1);
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = 
assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic),
+                Set.of(),
+                Set.of()
+            )
+        );
+        assertEquals(Status.INCORRECTLY_PARTITIONED_TOPICS, ex.status());
+        assertEquals("Following topics do not have the same number of 
partitions: " +
+            "[{first=2, second=1}]", ex.getMessage());
+    }
+
+
+    @Test
+    public void shouldEnforceCopartitioningOnRepartitionTopics() {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final String repartitionTopic = "repartitioned";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            firstSourceTopic, 2,
+            secondSourceTopic, 2,
+            repartitionTopic, 10
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> result =
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic, repartitionTopic),
+                Set.of(),
+                Set.of(repartitionTopic)
+            );
+
+        assertEquals(Map.of(repartitionTopic, 2), result);
+    }
+
+
+    @Test
+    public void 
shouldSetNumPartitionsToMaximumPartitionsWhenAllTopicsAreRepartitionTopics() {
+        final String repartitionTopic1 = "repartitionTopic1";
+        final String repartitionTopic2 = "repartitionTopic2";
+        final String repartitionTopic3 = "repartitionTopic3";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 1,
+            repartitionTopic2, 15,
+            repartitionTopic3, 5
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> result = enforcer.enforce(
+            Set.of(repartitionTopic1, repartitionTopic2, repartitionTopic3),
+            Set.of(),
+            Set.of(repartitionTopic1, repartitionTopic2, repartitionTopic3)
+        );
+
+        assertEquals(Map.of(
+            repartitionTopic1, 15,
+            repartitionTopic2, 15,
+            repartitionTopic3, 15
+        ), result);
+    }
+
+    @Test
+    public void 
shouldThrowAnExceptionIfTopicInfosWithEnforcedNumOfPartitionsHaveDifferentNumOfPartitions()
 {
+        final String repartitionTopic1 = "repartitioned-1";
+        final String repartitionTopic2 = "repartitioned-2";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 10,
+            repartitionTopic2, 5
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = assertThrows(
+            TopicConfigurationException.class,
+            () -> enforcer.enforce(
+                Set.of(repartitionTopic1, repartitionTopic2),
+                Set.of(repartitionTopic1, repartitionTopic2),
+                Set.of()
+            )
+        );
+
+        final TreeMap<String, Integer> sorted = new TreeMap<>(
+            Utils.mkMap(Utils.mkEntry(repartitionTopic1, 10),
+                Utils.mkEntry(repartitionTopic2, 5))
+        );
+        assertEquals(Status.INCORRECTLY_PARTITIONED_TOPICS, ex.status());
+        assertEquals(String.format(
+            "Following topics do not have the same number of partitions: " +
+                "[%s]", sorted), ex.getMessage());
+    }
+
+    @Test
+    public void 
shouldNotThrowAnExceptionWhenTopicInfosWithEnforcedNumOfPartitionsAreValid() {

Review Comment:
   Done



##########
group-coordinator/src/main/java/org/apache/kafka/coordinator/group/streams/topics/CopartitionedTopicsEnforcer.java:
##########
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.group.streams.topics;
+
+import org.apache.kafka.common.errors.StreamsInvalidTopologyException;
+import org.apache.kafka.common.utils.LogContext;
+
+import org.slf4j.Logger;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * This class is responsible for enforcing the number of partitions in 
copartitioned topics. For each copartition group, it checks whether
+ * the number of partitions for all repartition topics is the same, and 
enforces copartitioning for repartition topics whose number of
+ * partitions is not enforced by the topology.
+ */
+public class CopartitionedTopicsEnforcer {
+
+    private final Logger log;
+    private final Function<String, OptionalInt> topicPartitionCountProvider;
+
+    /**
+     * The constructor for the class.
+     *
+     * @param logContext                  The context for emitting log 
messages.
+     * @param topicPartitionCountProvider Returns the number of partitions for 
a given topic, representing the current state of the broker
+     *                                    as well as any partition number 
decisions that have already been made. In particular, we expect
+     *                                    the number of partitions for all 
repartition topics defined, even if they do not exist in the
+     *                                    broker yet.
+     */
+    public CopartitionedTopicsEnforcer(final LogContext logContext,
+                                       final Function<String, OptionalInt> 
topicPartitionCountProvider) {
+        this.log = logContext.logger(getClass());
+        this.topicPartitionCountProvider = topicPartitionCountProvider;
+    }
+
+    /**
+     * Enforces the number of partitions for copartitioned topics.
+     *
+     * @param copartitionedTopics          The set of copartitioned topics 
(external source topics and repartition topics).
+     * @param fixedRepartitionTopics       The set of repartition topics whose 
partition count is fixed by the topology.
+     * @param flexibleRepartitionTopics    The set of repartition topics whose 
partition count is flexible, and can be changed.
+     *
+     * @throws TopicConfigurationException If source topics are missing, or 
there are topics in copartitionedTopics that are not copartitioned
+     *                                     according to 
topicPartitionCountProvider.

Review Comment:
   Fixed



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/streams/topics/CopartitionedTopicsEnforcerTest.java:
##########
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.group.streams.topics;
+
+import org.apache.kafka.common.requests.StreamsGroupHeartbeatResponse.Status;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class CopartitionedTopicsEnforcerTest {
+
+    private static final LogContext LOG_CONTEXT = new LogContext();
+
+    private static Function<String, OptionalInt> 
topicPartitionProvider(Map<String, Integer> topicPartitionCounts) {
+        return topic -> {
+            Integer a = topicPartitionCounts.get(topic);
+            return a == null ? OptionalInt.empty() : OptionalInt.of(a);
+        };
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfNoPartitionsFoundForCoPartitionedTopic()
 {
+        final String topic = "topic";
+        final Map<String, Integer> topicPartitionCounts = 
Collections.emptyMap();
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = 
assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(topic),
+                Set.of(),
+                Set.of()
+            ));
+        assertEquals(Status.MISSING_SOURCE_TOPICS, ex.status());
+        assertEquals("Following topics are missing: [topic]", ex.getMessage());
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfPartitionCountsForCoPartitionedTopicsDontMatch()
 {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";

Review Comment:
   Interesting that you think this is better. Done



##########
group-coordinator/src/main/java/org/apache/kafka/coordinator/group/streams/topics/CopartitionedTopicsEnforcer.java:
##########
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.group.streams.topics;
+
+import org.apache.kafka.common.errors.StreamsInvalidTopologyException;
+import org.apache.kafka.common.utils.LogContext;
+
+import org.slf4j.Logger;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * This class is responsible for enforcing the number of partitions in 
copartitioned topics. For each copartition group, it checks whether
+ * the number of partitions for all repartition topics is the same, and 
enforces copartitioning for repartition topics whose number of
+ * partitions is not enforced by the topology.
+ */
+public class CopartitionedTopicsEnforcer {
+
+    private final Logger log;
+    private final Function<String, OptionalInt> topicPartitionCountProvider;
+
+    /**
+     * The constructor for the class.
+     *
+     * @param logContext                  The context for emitting log 
messages.
+     * @param topicPartitionCountProvider Returns the number of partitions for 
a given topic, representing the current state of the broker
+     *                                    as well as any partition number 
decisions that have already been made. In particular, we expect
+     *                                    the number of partitions for all 
repartition topics defined, even if they do not exist in the
+     *                                    broker yet.
+     */
+    public CopartitionedTopicsEnforcer(final LogContext logContext,
+                                       final Function<String, OptionalInt> 
topicPartitionCountProvider) {
+        this.log = logContext.logger(getClass());
+        this.topicPartitionCountProvider = topicPartitionCountProvider;
+    }
+
+    /**
+     * Enforces the number of partitions for copartitioned topics.
+     *
+     * @param copartitionedTopics          The set of copartitioned topics 
(external source topics and repartition topics).
+     * @param fixedRepartitionTopics       The set of repartition topics whose 
partition count is fixed by the topology.
+     * @param flexibleRepartitionTopics    The set of repartition topics whose 
partition count is flexible, and can be changed.
+     *
+     * @throws TopicConfigurationException If source topics are missing, or 
there are topics in copartitionedTopics that are not copartitioned
+     *                                     according to 
topicPartitionCountProvider.
+     *
+     * @return A map from all repartition topics in copartitionedTopics to 
their updated partition counts.
+     */
+    public Map<String, Integer> enforce(final Set<String> copartitionedTopics,
+                                        final Set<String> 
fixedRepartitionTopics,
+                                        final Set<String> 
flexibleRepartitionTopics) throws StreamsInvalidTopologyException {
+        if (copartitionedTopics.isEmpty()) {
+            return Collections.emptyMap();
+        }
+        final Map<String, Integer> returnedPartitionCounts = new HashMap<>();
+
+        final Map<String, Integer> repartitionTopicPartitionCounts =
+            copartitionedTopics.stream()
+                .filter(x -> fixedRepartitionTopics.contains(x) || 
flexibleRepartitionTopics.contains(x))
+                .collect(Collectors.toMap(topic -> topic, 
this::getPartitionCount));
+
+        final Map<String, Integer> nonRepartitionTopicPartitions =
+            copartitionedTopics.stream().filter(topic -> 
!repartitionTopicPartitionCounts.containsKey(topic))
+                .collect(Collectors.toMap(topic -> topic, topic -> {
+                    final OptionalInt topicPartitionCount = 
topicPartitionCountProvider.apply(topic);
+                    if (topicPartitionCount.isEmpty()) {
+                        final String str = String.format("Following topics are 
missing: [%s]", topic);
+                        log.error(str);
+                        throw 
TopicConfigurationException.missingSourceTopics(str);

Review Comment:
   In practice we always call `RepartitionTopics` before calling this class, so
this can never happen. However, that guarantee depends on how this class is
used. I think it is still good style for this "unit" to have well-defined
behavior if it is used in a context where a source topic is missing. Would you
propose that this class just implicitly depends on the validation happening
somewhere else?



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/streams/topics/CopartitionedTopicsEnforcerTest.java:
##########
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.group.streams.topics;
+
+import org.apache.kafka.common.requests.StreamsGroupHeartbeatResponse.Status;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class CopartitionedTopicsEnforcerTest {
+
+    private static final LogContext LOG_CONTEXT = new LogContext();
+
+    private static Function<String, OptionalInt> 
topicPartitionProvider(Map<String, Integer> topicPartitionCounts) {
+        return topic -> {
+            Integer a = topicPartitionCounts.get(topic);
+            return a == null ? OptionalInt.empty() : OptionalInt.of(a);
+        };
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfNoPartitionsFoundForCoPartitionedTopic()
 {
+        final String topic = "topic";
+        final Map<String, Integer> topicPartitionCounts = 
Collections.emptyMap();
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = 
assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(topic),
+                Set.of(),
+                Set.of()
+            ));
+        assertEquals(Status.MISSING_SOURCE_TOPICS, ex.status());
+        assertEquals("Following topics are missing: [topic]", ex.getMessage());
+    }
+
+    @Test
+    public void 
shouldThrowTopicConfigurationExceptionIfPartitionCountsForCoPartitionedTopicsDontMatch()
 {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final Map<String, Integer> topicPartitionCounts = 
Map.of(firstSourceTopic, 2, secondSourceTopic, 1);
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = 
assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic),
+                Set.of(),
+                Set.of()
+            )
+        );
+        assertEquals(Status.INCORRECTLY_PARTITIONED_TOPICS, ex.status());
+        assertEquals("Following topics do not have the same number of 
partitions: " +
+            "[{first=2, second=1}]", ex.getMessage());
+    }
+
+
+    @Test
+    public void shouldEnforceCopartitioningOnRepartitionTopics() {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final String repartitionTopic = "repartitioned";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            firstSourceTopic, 2,
+            secondSourceTopic, 2,
+            repartitionTopic, 10
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> result =
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic, repartitionTopic),
+                Set.of(),
+                Set.of(repartitionTopic)
+            );
+
+        assertEquals(Map.of(repartitionTopic, 2), result);
+    }
+
+
+    @Test
+    public void 
shouldSetNumPartitionsToMaximumPartitionsWhenAllTopicsAreRepartitionTopics() {
+        final String repartitionTopic1 = "repartitionTopic1";
+        final String repartitionTopic2 = "repartitionTopic2";
+        final String repartitionTopic3 = "repartitionTopic3";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 1,
+            repartitionTopic2, 15,
+            repartitionTopic3, 5
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> result = enforcer.enforce(
+            Set.of(repartitionTopic1, repartitionTopic2, repartitionTopic3),
+            Set.of(),
+            Set.of(repartitionTopic1, repartitionTopic2, repartitionTopic3)
+        );
+
+        assertEquals(Map.of(
+            repartitionTopic1, 15,
+            repartitionTopic2, 15,
+            repartitionTopic3, 15
+        ), result);
+    }
+
+    @Test
+    public void 
shouldThrowAnExceptionIfTopicInfosWithEnforcedNumOfPartitionsHaveDifferentNumOfPartitions()
 {
+        final String repartitionTopic1 = "repartitioned-1";
+        final String repartitionTopic2 = "repartitioned-2";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 10,
+            repartitionTopic2, 5
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, 
topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = assertThrows(
+            TopicConfigurationException.class,
+            () -> enforcer.enforce(
+                Set.of(repartitionTopic1, repartitionTopic2),
+                Set.of(repartitionTopic1, repartitionTopic2),
+                Set.of()
+            )
+        );
+
+        final TreeMap<String, Integer> sorted = new TreeMap<>(
+            Utils.mkMap(Utils.mkEntry(repartitionTopic1, 10),
+                Utils.mkEntry(repartitionTopic2, 5))

Review Comment:
   Done



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/streams/topics/CopartitionedTopicsEnforcerTest.java:
##########
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.group.streams.topics;
+
+import org.apache.kafka.common.requests.StreamsGroupHeartbeatResponse.Status;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class CopartitionedTopicsEnforcerTest {
+
+    private static final LogContext LOG_CONTEXT = new LogContext();
+
+    private static Function<String, OptionalInt> topicPartitionProvider(Map<String, Integer> topicPartitionCounts) {
+        return topic -> {
+            Integer a = topicPartitionCounts.get(topic);
+            return a == null ? OptionalInt.empty() : OptionalInt.of(a);
+        };
+    }
+
+    @Test
+    public void shouldThrowTopicConfigurationExceptionIfNoPartitionsFoundForCoPartitionedTopic() {
+        final String topic = "topic";
+        final Map<String, Integer> topicPartitionCounts = Collections.emptyMap();
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(topic),
+                Set.of(),
+                Set.of()
+            ));
+        assertEquals(Status.MISSING_SOURCE_TOPICS, ex.status());
+        assertEquals("Following topics are missing: [topic]", ex.getMessage());
+    }
+
+    @Test
+    public void shouldThrowTopicConfigurationExceptionIfPartitionCountsForCoPartitionedTopicsDontMatch() {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final Map<String, Integer> topicPartitionCounts = Map.of(firstSourceTopic, 2, secondSourceTopic, 1);
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = assertThrows(TopicConfigurationException.class, () ->
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic),
+                Set.of(),
+                Set.of()
+            )
+        );
+        assertEquals(Status.INCORRECTLY_PARTITIONED_TOPICS, ex.status());
+        assertEquals("Following topics do not have the same number of partitions: " +
+            "[{first=2, second=1}]", ex.getMessage());
+    }
+
+
+    @Test
+    public void shouldEnforceCopartitioningOnRepartitionTopics() {
+        final String firstSourceTopic = "first";
+        final String secondSourceTopic = "second";
+        final String repartitionTopic = "repartitioned";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            firstSourceTopic, 2,
+            secondSourceTopic, 2,
+            repartitionTopic, 10
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> result =
+            enforcer.enforce(
+                Set.of(firstSourceTopic, secondSourceTopic, repartitionTopic),
+                Set.of(),
+                Set.of(repartitionTopic)
+            );
+
+        assertEquals(Map.of(repartitionTopic, 2), result);
+    }
+
+
+    @Test
+    public void shouldSetNumPartitionsToMaximumPartitionsWhenAllTopicsAreRepartitionTopics() {
+        final String repartitionTopic1 = "repartitionTopic1";
+        final String repartitionTopic2 = "repartitionTopic2";
+        final String repartitionTopic3 = "repartitionTopic3";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 1,
+            repartitionTopic2, 15,
+            repartitionTopic3, 5
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> result = enforcer.enforce(
+            Set.of(repartitionTopic1, repartitionTopic2, repartitionTopic3),
+            Set.of(),
+            Set.of(repartitionTopic1, repartitionTopic2, repartitionTopic3)
+        );
+
+        assertEquals(Map.of(
+            repartitionTopic1, 15,
+            repartitionTopic2, 15,
+            repartitionTopic3, 15
+        ), result);
+    }
+
+    @Test
+    public void shouldThrowAnExceptionIfTopicInfosWithEnforcedNumOfPartitionsHaveDifferentNumOfPartitions() {
+        final String repartitionTopic1 = "repartitioned-1";
+        final String repartitionTopic2 = "repartitioned-2";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 10,
+            repartitionTopic2, 5
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = assertThrows(
+            TopicConfigurationException.class,
+            () -> enforcer.enforce(
+                Set.of(repartitionTopic1, repartitionTopic2),
+                Set.of(repartitionTopic1, repartitionTopic2),
+                Set.of()
+            )
+        );
+
+        final TreeMap<String, Integer> sorted = new TreeMap<>(
+            Utils.mkMap(Utils.mkEntry(repartitionTopic1, 10),
+                Utils.mkEntry(repartitionTopic2, 5))
+        );
+        assertEquals(Status.INCORRECTLY_PARTITIONED_TOPICS, ex.status());
+        assertEquals(String.format(
+            "Following topics do not have the same number of partitions: " +
+                "[%s]", sorted), ex.getMessage());
+    }
+
+    @Test
+    public void shouldNotThrowAnExceptionWhenTopicInfosWithEnforcedNumOfPartitionsAreValid() {
+        final String repartitionTopic1 = "repartitioned-1";
+        final String repartitionTopic2 = "repartitioned-2";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 10,
+            repartitionTopic2, 10
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, topicPartitionProvider(topicPartitionCounts));
+
+        final Map<String, Integer> enforced = enforcer.enforce(
+            Set.of(repartitionTopic1, repartitionTopic2),
+            Set.of(),
+            Set.of(repartitionTopic1, repartitionTopic2)
+        );
+
+        assertEquals(Map.of(
+            repartitionTopic1, 10,
+            repartitionTopic2, 10
+        ), enforced);
+    }
+
+    @Test
+    public void shouldThrowAnExceptionWhenNumberOfPartitionsOfNonRepartitionTopicAndRepartitionTopicWithEnforcedNumOfPartitionsDoNotMatch() {
+        final String repartitionTopic1 = "repartitioned-1";
+        final String firstSourceTopic = "first";
+        final Map<String, Integer> topicPartitionCounts = Map.of(
+            repartitionTopic1, 10,
+            firstSourceTopic, 2
+        );
+        final CopartitionedTopicsEnforcer enforcer =
+            new CopartitionedTopicsEnforcer(LOG_CONTEXT, topicPartitionProvider(topicPartitionCounts));
+
+        final TopicConfigurationException ex = assertThrows(
+            TopicConfigurationException.class,
+            () -> enforcer.enforce(
+                Set.of(repartitionTopic1, firstSourceTopic),
+                Set.of(repartitionTopic1),
+                Set.of())
+        );
+
+        assertEquals(Status.INCORRECTLY_PARTITIONED_TOPICS, ex.status());
+        assertEquals(String.format("Number of partitions [%s] " +
+                "of repartition topic [%s] " +
+                "doesn't match number of partitions [%s] of the source topic.",
+            10, repartitionTopic1, 2), ex.getMessage());
+    }
+
+    @Test
+    public void shouldNotThrowAnExceptionWhenNumberOfPartitionsOfNonRepartitionTopicAndRepartitionTopicWithEnforcedNumOfPartitionsMatch() {

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to