Repository: kafka Updated Branches: refs/heads/trunk c6b8de4e6 -> 124f73b17
KAFKA-2763: better stream task assignment guozhangwang When the rebalance happens, each consumer reports the following information to the coordinator. * Client UUID (a unique id assigned to an instance of KafkaStreaming) * Task ids of previously running tasks * Task ids of valid local states on the client's state directory TaskAssignor does the following * Assign a task to a client which was running it previously. If there is no such client, assign a task to a client which has its valid local state. * Try to balance the load among stream threads. * A client may have more than one stream thread. The assignor tries to assign tasks to a client proportionally to the number of threads. Author: Yasuhiro Matsuda <[email protected]> Reviewers: Guozhang Wang Closes #497 from ymatsuda/task_assignment Project: http://git-wip-us.apache.org/repos/asf/kafka/repo Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/124f73b1 Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/124f73b1 Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/124f73b1 Branch: refs/heads/trunk Commit: 124f73b1747a574982e9ca491712e6758ddbacea Parents: c6b8de4 Author: Yasuhiro Matsuda <[email protected]> Authored: Wed Nov 11 16:14:27 2015 -0800 Committer: Guozhang Wang <[email protected]> Committed: Wed Nov 11 16:14:27 2015 -0800 ---------------------------------------------------------------------- .../apache/kafka/streams/KafkaStreaming.java | 5 +- .../apache/kafka/streams/StreamingConfig.java | 8 +- .../streams/processor/PartitionGrouper.java | 4 + .../apache/kafka/streams/processor/TaskId.java | 23 +- .../KafkaStreamingPartitionAssignor.java | 187 ++++++++---- .../processor/internals/StreamThread.java | 59 +++- .../internals/assignment/AssignmentInfo.java | 125 ++++++++ .../internals/assignment/ClientState.java | 72 +++++ .../internals/assignment/SubscriptionInfo.java | 128 ++++++++ .../assignment/TaskAssignmentException.java | 32 ++ .../internals/assignment/TaskAssignor.java | 195 
+++++++++++++ .../KafkaStreamingPartitionAssignorTest.java | 283 ++++++++++++++++++ .../processor/internals/StreamThreadTest.java | 33 ++- .../assignment/AssginmentInfoTest.java | 45 +++ .../assignment/SubscriptionInfoTest.java | 46 +++ .../internals/assignment/TaskAssignorTest.java | 289 +++++++++++++++++++ 16 files changed, 1464 insertions(+), 70 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java index d274fb9..fc1fdae 100644 --- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java +++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java @@ -29,6 +29,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; +import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -85,11 +86,13 @@ public class KafkaStreaming { private final StreamThread[] threads; private String clientId; + private final UUID uuid; private final Metrics metrics; public KafkaStreaming(TopologyBuilder builder, StreamingConfig config) throws Exception { // create the metrics this.time = new SystemTime(); + this.uuid = UUID.randomUUID(); MetricConfig metricConfig = new MetricConfig().samples(config.getInt(StreamingConfig.METRICS_NUM_SAMPLES_CONFIG)) .timeWindow(config.getLong(StreamingConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), @@ -104,7 +107,7 @@ public class KafkaStreaming { this.threads = new StreamThread[config.getInt(StreamingConfig.NUM_STREAM_THREADS_CONFIG)]; for (int i = 0; i < this.threads.length; i++) { - this.threads[i] = new StreamThread(builder, config, this.clientId, this.metrics, this.time); + 
this.threads[i] = new StreamThread(builder, config, this.clientId, this.uuid, this.metrics, this.time); } } http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java index 88bd844..693cb0c 100644 --- a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java +++ b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java @@ -27,8 +27,8 @@ import org.apache.kafka.common.config.ConfigDef.Type; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.Serializer; import org.apache.kafka.streams.processor.DefaultPartitionGrouper; -import org.apache.kafka.streams.processor.PartitionGrouper; import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor; +import org.apache.kafka.streams.processor.internals.StreamThread; import java.util.Map; @@ -205,16 +205,16 @@ public class StreamingConfig extends AbstractConfig { } public static class InternalConfig { - public static final String PARTITION_GROUPER_INSTANCE = "__partition.grouper.instance__"; + public static final String STREAM_THREAD_INSTANCE = "__stream.thread.instance__"; } public StreamingConfig(Map<?, ?> props) { super(CONFIG, props); } - public Map<String, Object> getConsumerConfigs(PartitionGrouper partitionGrouper) { + public Map<String, Object> getConsumerConfigs(StreamThread streamThread) { Map<String, Object> props = getConsumerConfigs(); - props.put(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper); + props.put(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, streamThread); props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, KafkaStreamingPartitionAssignor.class.getName()); return props; } 
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java index 026ec89..00b56b3 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java @@ -50,4 +50,8 @@ public abstract class PartitionGrouper { return partitionAssignor.taskIds(partition); } + public Set<TaskId> standbyTasks() { + return partitionAssignor.standbyTasks(); + } + } http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java index 3d474fe..5344f6c 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java @@ -17,7 +17,9 @@ package org.apache.kafka.streams.processor; -public class TaskId { +import java.nio.ByteBuffer; + +public class TaskId implements Comparable<TaskId> { public final int topicGroupId; public final int partition; @@ -45,6 +47,15 @@ public class TaskId { } } + public void writeTo(ByteBuffer buf) { + buf.putInt(topicGroupId); + buf.putInt(partition); + } + + public static TaskId readFrom(ByteBuffer buf) { + return new TaskId(buf.getInt(), buf.getInt()); + } + @Override public boolean equals(Object o) { if (o instanceof TaskId) { @@ -61,6 +72,16 @@ public class TaskId { return (int) (n % 0xFFFFFFFFL); } + @Override + public int compareTo(TaskId other) { + return + 
this.topicGroupId < other.topicGroupId ? -1 : + (this.topicGroupId > other.topicGroupId ? 1 : + (this.partition < other.partition ? -1 : + (this.partition > other.partition ? 1 : + 0))); + } + public static class TaskIdFormatException extends RuntimeException { } } http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java index f7b14ad..35ba0ec 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java @@ -23,37 +23,49 @@ import org.apache.kafka.common.Configurable; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.streams.StreamingConfig; -import org.apache.kafka.streams.processor.PartitionGrouper; import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo; +import org.apache.kafka.streams.processor.internals.assignment.ClientState; +import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo; +import org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentException; +import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.UUID; public class 
KafkaStreamingPartitionAssignor implements PartitionAssignor, Configurable { private static final Logger log = LoggerFactory.getLogger(KafkaStreamingPartitionAssignor.class); - private PartitionGrouper partitionGrouper; + private StreamThread streamThread; private Map<TopicPartition, Set<TaskId>> partitionToTaskIds; + private Set<TaskId> standbyTasks; @Override public void configure(Map<String, ?> configs) { - Object o = configs.get(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE); - if (o == null) - throw new KafkaException("PartitionGrouper is not specified"); + Object o = configs.get(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE); + if (o == null) { + KafkaException ex = new KafkaException("StreamThread is not specified"); + log.error(ex.getMessage(), ex); + throw ex; + } - if (!PartitionGrouper.class.isInstance(o)) - throw new KafkaException(o.getClass().getName() + " is not an instance of " + PartitionGrouper.class.getName()); + if (!(o instanceof StreamThread)) { + KafkaException ex = new KafkaException(o.getClass().getName() + " is not an instance of " + StreamThread.class.getName()); + log.error(ex.getMessage(), ex); + throw ex; + } - partitionGrouper = (PartitionGrouper) o; - partitionGrouper.partitionAssignor(this); + streamThread = (StreamThread) o; + streamThread.partitionGrouper.partitionAssignor(this); } @Override @@ -63,38 +75,110 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi @Override public Subscription subscription(Set<String> topics) { - return new Subscription(new ArrayList<>(topics)); + // Adds the following information to subscription + // 1. Client UUID (a unique id assigned to an instance of KafkaStreaming) + // 2. Task ids of previously running tasks + // 3. Task ids of valid local states on the client's state directory. 
+ + Set<TaskId> prevTasks = streamThread.prevTasks(); + Set<TaskId> standbyTasks = streamThread.cachedTasks(); + standbyTasks.removeAll(prevTasks); + SubscriptionInfo data = new SubscriptionInfo(streamThread.clientUUID, prevTasks, standbyTasks); + + return new Subscription(new ArrayList<>(topics), data.encode()); } @Override public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription> subscriptions) { - Map<TaskId, Set<TopicPartition>> partitionGroups = partitionGrouper.partitionGroups(metadata); + // This assigns tasks to consumer clients in two steps. + // 1. using TaskAssignor tasks are assigned to streaming clients. + // - Assign a task to a client which was running it previously. + // If there is no such client, assign a task to a client which has its valid local state. + // - A client may have more than one stream threads. + // The assignor tries to assign tasks to a client proportionally to the number of threads. + // - We try not to assign the same set of tasks to two different clients + // We do the assignment in one-pass. The result may not satisfy above all. + // 2. within each client, tasks are assigned to consumer clients in round-robin manner. 
+ + Map<UUID, Set<String>> consumersByClient = new HashMap<>(); + Map<UUID, ClientState<TaskId>> states = new HashMap<>(); + + // Decode subscription info + for (Map.Entry<String, Subscription> entry : subscriptions.entrySet()) { + String consumerId = entry.getKey(); + Subscription subscription = entry.getValue(); + + SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData()); + + Set<String> consumers = consumersByClient.get(info.clientUUID); + if (consumers == null) { + consumers = new HashSet<>(); + consumersByClient.put(info.clientUUID, consumers); + } + consumers.add(consumerId); + + ClientState<TaskId> state = states.get(info.clientUUID); + if (state == null) { + state = new ClientState<>(); + states.put(info.clientUUID, state); + } + + state.prevActiveTasks.addAll(info.prevTasks); + state.prevAssignedTasks.addAll(info.prevTasks); + state.prevAssignedTasks.addAll(info.standbyTasks); + state.capacity = state.capacity + 1d; + } - String[] clientIds = subscriptions.keySet().toArray(new String[subscriptions.size()]); - TaskId[] taskIds = partitionGroups.keySet().toArray(new TaskId[partitionGroups.size()]); + // Get partition groups from the partition grouper + Map<TaskId, Set<TopicPartition>> partitionGroups = streamThread.partitionGrouper.partitionGroups(metadata); + states = TaskAssignor.assign(states, partitionGroups.keySet(), 0); // TODO: enable standby tasks Map<String, Assignment> assignment = new HashMap<>(); - for (int i = 0; i < clientIds.length; i++) { - List<TopicPartition> partitions = new ArrayList<>(); - List<TaskId> ids = new ArrayList<>(); - for (int j = i; j < taskIds.length; j += clientIds.length) { - TaskId taskId = taskIds[j]; - for (TopicPartition partition : partitionGroups.get(taskId)) { - partitions.add(partition); - ids.add(taskId); - } + for (Map.Entry<UUID, Set<String>> entry : consumersByClient.entrySet()) { + UUID uuid = entry.getKey(); + Set<String> consumers = entry.getValue(); + ClientState<TaskId> state = 
states.get(uuid); + + ArrayList<TaskId> taskIds = new ArrayList<>(state.assignedTasks.size()); + final int numActiveTasks = state.activeTasks.size(); + for (TaskId id : state.activeTasks) { + taskIds.add(id); } - ByteBuffer buf = ByteBuffer.allocate(4 + ids.size() * 8); - //version - buf.putInt(1); - // encode task ids - for (TaskId id : ids) { - buf.putInt(id.topicGroupId); - buf.putInt(id.partition); + for (TaskId id : state.assignedTasks) { + if (!state.activeTasks.contains(id)) + taskIds.add(id); + } + + final int numConsumers = consumers.size(); + List<TaskId> active = new ArrayList<>(); + Set<TaskId> standby = new HashSet<>(); + + int i = 0; + for (String consumer : consumers) { + List<TopicPartition> partitions = new ArrayList<>(); + + final int numTaskIds = taskIds.size(); + for (int j = i; j < numTaskIds; j += numConsumers) { + TaskId taskId = taskIds.get(j); + if (j < numActiveTasks) { + for (TopicPartition partition : partitionGroups.get(taskId)) { + partitions.add(partition); + active.add(taskId); + } + } else { + // no partition to a standby task + standby.add(taskId); + } + } + + AssignmentInfo data = new AssignmentInfo(active, standby); + assignment.put(consumer, new Assignment(partitions, data.encode())); + i++; + + active.clear(); + standby.clear(); } - buf.rewind(); - assignment.put(clientIds[i], new Assignment(partitions, buf)); } return assignment; @@ -103,27 +187,29 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi @Override public void onAssignment(Assignment assignment) { List<TopicPartition> partitions = assignment.partitions(); - ByteBuffer data = assignment.userData(); - data.rewind(); + + AssignmentInfo info = AssignmentInfo.decode(assignment.userData()); + this.standbyTasks = info.standbyTasks; Map<TopicPartition, Set<TaskId>> partitionToTaskIds = new HashMap<>(); + Iterator<TaskId> iter = info.activeTasks.iterator(); + for (TopicPartition partition : partitions) { + Set<TaskId> taskIds = 
partitionToTaskIds.get(partition); + if (taskIds == null) { + taskIds = new HashSet<>(); + partitionToTaskIds.put(partition, taskIds); + } - // check version - int version = data.getInt(); - if (version == 1) { - for (TopicPartition partition : partitions) { - Set<TaskId> taskIds = partitionToTaskIds.get(partition); - if (taskIds == null) { - taskIds = new HashSet<>(); - partitionToTaskIds.put(partition, taskIds); - } - // decode a task id - taskIds.add(new TaskId(data.getInt(), data.getInt())); + if (iter.hasNext()) { + taskIds.add(iter.next()); + } else { + TaskAssignmentException ex = new TaskAssignmentException( + "failed to find a task id for the partition=" + partition.toString() + + ", partitions=" + partitions.size() + ", assignmentInfo=" + info.toString() + ); + log.error(ex.getMessage(), ex); + throw ex; } - } else { - KafkaException ex = new KafkaException("unknown assignment data version: " + version); - log.error(ex.getMessage(), ex); - throw ex; } this.partitionToTaskIds = partitionToTaskIds; } @@ -132,4 +218,7 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi return partitionToTaskIds.get(partition); } + public Set<TaskId> standbyTasks() { + return standbyTasks; + } } http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index ba81421..06e5951 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -59,6 +59,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.UUID; import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -67,16 +68,18 @@ public class StreamThread extends Thread { private static final Logger log = LoggerFactory.getLogger(StreamThread.class); private static final AtomicInteger STREAMING_THREAD_ID_SEQUENCE = new AtomicInteger(1); - private final AtomicBoolean running; + public final PartitionGrouper partitionGrouper; + public final UUID clientUUID; protected final StreamingConfig config; protected final TopologyBuilder builder; - protected final PartitionGrouper partitionGrouper; protected final Producer<byte[], byte[]> producer; protected final Consumer<byte[], byte[]> consumer; protected final Consumer<byte[], byte[]> restoreConsumer; + private final AtomicBoolean running; private final Map<TaskId, StreamTask> tasks; + private final Set<TaskId> prevTasks; private final String clientId; private final Time time; private final File stateDir; @@ -108,9 +111,10 @@ public class StreamThread extends Thread { public StreamThread(TopologyBuilder builder, StreamingConfig config, String clientId, + UUID clientUUID, Metrics metrics, Time time) throws Exception { - this(builder, config, null , null, null, clientId, metrics, time); + this(builder, config, null , null, null, clientId, clientUUID, metrics, time); } StreamThread(TopologyBuilder builder, @@ -119,6 +123,7 @@ public class StreamThread extends Thread { Consumer<byte[], byte[]> consumer, Consumer<byte[], byte[]> restoreConsumer, String clientId, + UUID clientUUID, Metrics metrics, Time time) throws Exception { super("StreamThread-" + STREAMING_THREAD_ID_SEQUENCE.getAndIncrement()); @@ -126,6 +131,7 @@ public class StreamThread extends Thread { this.config = config; this.builder = builder; this.clientId = clientId; + this.clientUUID = clientUUID; this.partitionGrouper = config.getConfiguredInstance(StreamingConfig.PARTITION_GROUPER_CLASS_CONFIG, PartitionGrouper.class); this.partitionGrouper.topicGroups(builder.topicGroups()); 
@@ -136,6 +142,7 @@ public class StreamThread extends Thread { // initialize the task list this.tasks = new HashMap<>(); + this.prevTasks = new HashSet<>(); // read in task specific config values this.stateDir = new File(this.config.getString(StreamingConfig.STATE_DIR_CONFIG)); @@ -164,7 +171,7 @@ public class StreamThread extends Thread { private Consumer<byte[], byte[]> createConsumer() { log.info("Creating consumer client for stream thread [" + this.getName() + "]"); - return new KafkaConsumer<>(config.getConsumerConfigs(partitionGrouper), + return new KafkaConsumer<>(config.getConsumerConfigs(this), new ByteArrayDeserializer(), new ByteArrayDeserializer()); } @@ -415,6 +422,43 @@ public class StreamThread extends Thread { } } + /** + * Returns ids of tasks that were being executed before the rebalance. + */ + public Set<TaskId> prevTasks() { + return prevTasks; + } + + /** + * Returns ids of tasks whose states are kept on the local storage. + */ + public Set<TaskId> cachedTasks() { + // A client could contain some inactive tasks whose states are still kept on the local storage in the following scenarios: + // 1) the client is actively maintaining standby tasks by maintaining their states from the change log. + // 2) the client has just got some tasks migrated out of itself to other clients while these task states + // have not been cleaned up yet (this can happen in a rolling bounce upgrade, for example). + + HashSet<TaskId> tasks = new HashSet<>(); + + File[] stateDirs = stateDir.listFiles(); + if (stateDirs != null) { + for (File dir : stateDirs) { + try { + TaskId id = TaskId.parse(dir.getName()); + // if the checkpoint file exists, the state is valid. 
+ if (new File(dir, ProcessorStateManager.CHECKPOINT_FILE_NAME).exists()) + tasks.add(id); + + } catch (TaskId.TaskIdFormatException e) { + // there may be some unknown files that sits in the same directory, + // we should ignore these files instead trying to delete them as well + } + } + } + + return tasks; + } + protected StreamTask createStreamTask(TaskId id, Collection<TopicPartition> partitionsForTask) { sensors.taskCreationSensor.record(); @@ -465,11 +509,10 @@ public class StreamThread extends Thread { } sensors.taskDestructionSensor.record(); } - tasks.clear(); - } + prevTasks.clear(); + prevTasks.addAll(tasks.keySet()); - public PartitionGrouper partitionGrouper() { - return partitionGrouper; + tasks.clear(); } private void ensureCopartitioning(Collection<Set<String>> copartitionGroups) { http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java new file mode 100644 index 0000000..d82dd7d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java @@ -0,0 +1,125 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class AssignmentInfo { + + private static final Logger log = LoggerFactory.getLogger(AssignmentInfo.class); + + public final int version; + public final List<TaskId> activeTasks; // each element corresponds to a partition + public final Set<TaskId> standbyTasks; + + public AssignmentInfo(List<TaskId> activeTasks, Set<TaskId> standbyTasks) { + this(1, activeTasks, standbyTasks); + } + + protected AssignmentInfo(int version, List<TaskId> activeTasks, Set<TaskId> standbyTasks) { + this.version = version; + this.activeTasks = activeTasks; + this.standbyTasks = standbyTasks; + } + + public ByteBuffer encode() { + if (version == 1) { + ByteBuffer buf = ByteBuffer.allocate(4 + 4 + activeTasks.size() * 8 + 4 + standbyTasks.size() * 8); + // Encode version + buf.putInt(1); + // Encode active tasks + buf.putInt(activeTasks.size()); + for (TaskId id : activeTasks) { + id.writeTo(buf); + } + // Encode standby tasks + buf.putInt(standbyTasks.size()); + for (TaskId id : standbyTasks) { + id.writeTo(buf); + } + buf.rewind(); + + return buf; + + } else { + TaskAssignmentException ex = new TaskAssignmentException("unable to encode assignment data: version=" + version); + log.error(ex.getMessage(), ex); + throw ex; + } + } + + public static 
AssignmentInfo decode(ByteBuffer data) { + // ensure we are at the beginning of the ByteBuffer + data.rewind(); + + // Decode version + int version = data.getInt(); + if (version == 1) { + // Decode active tasks + int count = data.getInt(); + List<TaskId> activeTasks = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + activeTasks.add(TaskId.readFrom(data)); + } + // Decode standby tasks + count = data.getInt(); + Set<TaskId> standbyTasks = new HashSet<>(count); + for (int i = 0; i < count; i++) { + standbyTasks.add(TaskId.readFrom(data)); + } + + return new AssignmentInfo(activeTasks, standbyTasks); + + } else { + TaskAssignmentException ex = new TaskAssignmentException("unknown assignment data version: " + version); + log.error(ex.getMessage(), ex); + throw ex; + } + } + + @Override + public int hashCode() { + return version ^ activeTasks.hashCode() ^ standbyTasks.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof AssignmentInfo) { + AssignmentInfo other = (AssignmentInfo) o; + return this.version == other.version && + this.activeTasks.equals(other.activeTasks) && + this.standbyTasks.equals(other.standbyTasks); + } else { + return false; + } + } + + @Override + public String toString() { + return "[version=" + version + ", active tasks=" + activeTasks.size() + ", standby tasks=" + standbyTasks.size() + "]"; + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java new file mode 100644 index 0000000..a0f6179 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java @@ -0,0 +1,72 @@ +/** + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.HashSet; +import java.util.Set; + +public class ClientState<T> { + + public final static double COST_ACTIVE = 0.1; + public final static double COST_STANDBY = 0.2; + public final static double COST_LOAD = 0.5; + + public final Set<T> activeTasks; + public final Set<T> assignedTasks; + public final Set<T> prevActiveTasks; + public final Set<T> prevAssignedTasks; + + public double capacity; + public double cost; + + public ClientState() { + this(0d); + } + + public ClientState(double capacity) { + this(new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), capacity); + } + + private ClientState(Set<T> activeTasks, Set<T> assignedTasks, Set<T> prevActiveTasks, Set<T> prevAssignedTasks, double capacity) { + this.activeTasks = activeTasks; + this.assignedTasks = assignedTasks; + this.prevActiveTasks = prevActiveTasks; + this.prevAssignedTasks = prevAssignedTasks; + this.capacity = capacity; + this.cost = 0d; + } + + public ClientState<T> copy() { + return new ClientState<>(new HashSet<>(activeTasks), new HashSet<>(assignedTasks), + new HashSet<>(prevActiveTasks), new 
HashSet<>(prevAssignedTasks), capacity); + } + + public void assign(T taskId, boolean active) { + if (active) + activeTasks.add(taskId); + + assignedTasks.add(taskId); + + double cost = COST_LOAD; + cost = prevAssignedTasks.remove(taskId) ? COST_STANDBY : cost; + cost = prevActiveTasks.remove(taskId) ? COST_ACTIVE : cost; + + this.cost += cost; + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java new file mode 100644 index 0000000..54042b9 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java @@ -0,0 +1,128 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.HashSet; +import java.util.Set; +import java.util.UUID; + +public class SubscriptionInfo { + + private static final Logger log = LoggerFactory.getLogger(SubscriptionInfo.class); + + public final int version; + public final UUID clientUUID; + public final Set<TaskId> prevTasks; + public final Set<TaskId> standbyTasks; + + public SubscriptionInfo(UUID clientUUID, Set<TaskId> prevTasks, Set<TaskId> standbyTasks) { + this(1, clientUUID, prevTasks, standbyTasks); + } + + private SubscriptionInfo(int version, UUID clientUUID, Set<TaskId> prevTasks, Set<TaskId> standbyTasks) { + this.version = version; + this.clientUUID = clientUUID; + this.prevTasks = prevTasks; + this.standbyTasks = standbyTasks; + } + + public ByteBuffer encode() { + if (version == 1) { + ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + prevTasks.size() * 8 + 4 + standbyTasks.size() * 8); + // version + buf.putInt(1); + // encode client UUID + buf.putLong(clientUUID.getMostSignificantBits()); + buf.putLong(clientUUID.getLeastSignificantBits()); + // encode ids of previously running tasks + buf.putInt(prevTasks.size()); + for (TaskId id : prevTasks) { + id.writeTo(buf); + } + // encode ids of cached tasks + buf.putInt(standbyTasks.size()); + for (TaskId id : standbyTasks) { + id.writeTo(buf); + } + buf.rewind(); + + return buf; + + } else { + TaskAssignmentException ex = new TaskAssignmentException("unable to encode subscription data: version=" + version); + log.error(ex.getMessage(), ex); + throw ex; + } + } + + public static SubscriptionInfo decode(ByteBuffer data) { + // ensure we are at the beginning of the ByteBuffer + data.rewind(); + + // Decode version + int version = data.getInt(); + if (version == 1) { + // Decode client UUID + UUID clientUUID = new 
UUID(data.getLong(), data.getLong()); + // Decode previously active tasks + Set<TaskId> prevTasks = new HashSet<>(); + int numPrevs = data.getInt(); + for (int i = 0; i < numPrevs; i++) { + TaskId id = TaskId.readFrom(data); + prevTasks.add(id); + } + // Decode previously cached tasks + Set<TaskId> standbyTasks = new HashSet<>(); + int numCached = data.getInt(); + for (int i = 0; i < numCached; i++) { + standbyTasks.add(TaskId.readFrom(data)); + } + + return new SubscriptionInfo(version, clientUUID, prevTasks, standbyTasks); + + } else { + TaskAssignmentException ex = new TaskAssignmentException("unable to decode subscription data: version=" + version); + log.error(ex.getMessage(), ex); + throw ex; + } + } + + @Override + public int hashCode() { + return version ^ clientUUID.hashCode() ^ prevTasks.hashCode() ^ standbyTasks.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof SubscriptionInfo) { + SubscriptionInfo other = (SubscriptionInfo) o; + return this.version == other.version && + this.clientUUID.equals(other.clientUUID) && + this.prevTasks.equals(other.prevTasks) && + this.standbyTasks.equals(other.standbyTasks); + } else { + return false; + } + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java new file mode 100644 index 0000000..839a6c2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.common.KafkaException; + +/** + * The run time exception class for stream task assignments + */ +public class TaskAssignmentException extends KafkaException { + + private final static long serialVersionUID = 1L; + + public TaskAssignmentException(String message) { + super(message); + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java ---------------------------------------------------------------------- diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java new file mode 100644 index 0000000..d1e0782 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java @@ -0,0 +1,195 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals.assignment; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Random; +import java.util.Set; + +public class TaskAssignor<C, T extends Comparable<T>> { + + private static final Logger log = LoggerFactory.getLogger(TaskAssignor.class); + + public static <C, T extends Comparable<T>> Map<C, ClientState<T>> assign(Map<C, ClientState<T>> states, Set<T> tasks, int numStandbyReplicas) { + long seed = 0L; + for (C client : states.keySet()) { + seed += client.hashCode(); + } + + TaskAssignor<C, T> assignor = new TaskAssignor<>(states, tasks, seed); + assignor.assignTasks(); + if (numStandbyReplicas > 0) + assignor.assignStandbyTasks(numStandbyReplicas); + + return assignor.states; + } + + private final Random rand; + private final Map<C, ClientState<T>> states; + private final Set<TaskPair<T>> taskPairs; + private final int maxNumTaskPairs; + private final ArrayList<T> tasks; + + private TaskAssignor(Map<C, ClientState<T>> states, Set<T> tasks, long randomSeed) { + this.rand = new Random(randomSeed); + this.states = new HashMap<>(); + for (Map.Entry<C, ClientState<T>> entry : states.entrySet()) { + this.states.put(entry.getKey(), entry.getValue().copy()); + } + this.tasks = 
new ArrayList<>(tasks); + + int numTasks = tasks.size(); + this.maxNumTaskPairs = numTasks * (numTasks - 1) / 2; + this.taskPairs = new HashSet<>(this.maxNumTaskPairs); + } + + public void assignTasks() { + assignTasks(true); + } + + public void assignStandbyTasks(int numStandbyReplicas) { + int numReplicas = Math.min(numStandbyReplicas, states.size() - 1); + for (int i = 0; i < numReplicas; i++) { + assignTasks(false); + } + } + + private void assignTasks(boolean active) { + Collections.shuffle(this.tasks, rand); + + for (T task : tasks) { + ClientState<T> state = findClientFor(task); + + if (state != null) { + state.assign(task, active); + } else { + TaskAssignmentException ex = new TaskAssignmentException("failed to find an assignable client"); + log.error(ex.getMessage(), ex); + throw ex; + } + } + } + + private ClientState<T> findClientFor(T task) { + boolean checkTaskPairs = taskPairs.size() < maxNumTaskPairs; + + ClientState<T> state = findClientByAdditionCost(task, checkTaskPairs); + + if (state == null && checkTaskPairs) + state = findClientByAdditionCost(task, false); + + if (state != null) + addTaskPairs(task, state); + + return state; + } + + private ClientState<T> findClientByAdditionCost(T task, boolean checkTaskPairs) { + ClientState<T> candidate = null; + double candidateAdditionCost = 0d; + + for (ClientState<T> state : states.values()) { + if (!state.assignedTasks.contains(task)) { + // if checkTaskPairs flag is on, skip this client if this task doesn't introduce a new task combination + if (checkTaskPairs && !state.assignedTasks.isEmpty() && !hasNewTaskPair(task, state)) + continue; + + double additionCost = computeAdditionCost(task, state); + if (candidate == null || + (additionCost < candidateAdditionCost || + (additionCost == candidateAdditionCost && state.cost < candidate.cost))) { + candidate = state; + candidateAdditionCost = additionCost; + } + } + } + + return candidate; + } + + private void addTaskPairs(T task, ClientState<T> state) { + 
for (T other : state.assignedTasks) { + taskPairs.add(pair(task, other)); + } + } + + private boolean hasNewTaskPair(T task, ClientState<T> state) { + for (T other : state.assignedTasks) { + if (!taskPairs.contains(pair(task, other))) + return true; + } + return false; + } + + private double computeAdditionCost(T task, ClientState<T> state) { + double cost = Math.floor((double) state.assignedTasks.size() / state.capacity); + + if (state.prevAssignedTasks.contains(task)) { + if (state.prevActiveTasks.contains(task)) { + cost += ClientState.COST_ACTIVE; + } else { + cost += ClientState.COST_STANDBY; + } + } else { + cost += ClientState.COST_LOAD; + } + + return cost; + } + + private TaskPair<T> pair(T task1, T task2) { + if (task1.compareTo(task2) < 0) { + return new TaskPair<>(task1, task2); + } else { + return new TaskPair<>(task2, task1); + } + } + + private static class TaskPair<T> { + public final T task1; + public final T task2; + + public TaskPair(T task1, T task2) { + this.task1 = task1; + this.task2 = task2; + } + + @Override + public int hashCode() { + return task1.hashCode() ^ task2.hashCode(); + } + + @SuppressWarnings("unchecked") + @Override + public boolean equals(Object o) { + if (o instanceof TaskPair) { + TaskPair<T> other = (TaskPair<T>) o; + return this.task1.equals(other.task1) && this.task2.equals(other.task2); + } + return false; + } + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java ---------------------------------------------------------------------- diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java new file mode 100644 index 0000000..86434fb --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java @@ -0,0 +1,283 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.consumer.internals.PartitionAssignor; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamingConfig; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.TopologyBuilder; +import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo; +import 
org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo; +import org.apache.kafka.test.MockProcessorSupplier; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; + +public class KafkaStreamingPartitionAssignorTest { + + private TopicPartition t1p0 = new TopicPartition("topic1", 0); + private TopicPartition t1p1 = new TopicPartition("topic1", 1); + private TopicPartition t1p2 = new TopicPartition("topic1", 2); + private TopicPartition t2p0 = new TopicPartition("topic2", 0); + private TopicPartition t2p1 = new TopicPartition("topic2", 1); + private TopicPartition t2p2 = new TopicPartition("topic2", 2); + private TopicPartition t2p3 = new TopicPartition("topic2", 3); + + private List<PartitionInfo> infos = Arrays.asList( + new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new Node[0]) + ); + + private Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos, Collections.<String>emptySet()); + + private ByteBuffer subscriptionUserData() { + UUID uuid = UUID.randomUUID(); + ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + 4); + // version + buf.putInt(1); + // encode client clientUUID + buf.putLong(uuid.getMostSignificantBits()); + buf.putLong(uuid.getLeastSignificantBits()); + // previously running tasks + buf.putInt(0); + // 
cached tasks + buf.putInt(0); + buf.rewind(); + return buf; + } + + private final TaskId task0 = new TaskId(0, 0); + private final TaskId task1 = new TaskId(0, 1); + private final TaskId task2 = new TaskId(0, 2); + private final TaskId task3 = new TaskId(0, 3); + + private Properties configProps() { + return new Properties() { + { + setProperty(StreamingConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); + setProperty(StreamingConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + setProperty(StreamingConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); + setProperty(StreamingConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + setProperty(StreamingConfig.TIMESTAMP_EXTRACTOR_CLASS_CONFIG, "org.apache.kafka.test.MockTimestampExtractor"); + setProperty(StreamingConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"); + setProperty(StreamingConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3"); + } + }; + } + + private static class TestStreamTask extends StreamTask { + public boolean committed = false; + + public TestStreamTask(TaskId id, + Consumer<byte[], byte[]> consumer, + Producer<byte[], byte[]> producer, + Consumer<byte[], byte[]> restoreConsumer, + Collection<TopicPartition> partitions, + ProcessorTopology topology, + StreamingConfig config) { + super(id, consumer, producer, restoreConsumer, partitions, topology, config, null); + } + + @Override + public void commit() { + super.commit(); + committed = true; + } + } + + private ByteArraySerializer serializer = new ByteArraySerializer(); + + @SuppressWarnings("unchecked") + @Test + public void testSubscription() throws Exception { + StreamingConfig config = new StreamingConfig(configProps()); + + MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer); + MockConsumer<byte[], byte[]> consumer = 
new MockConsumer<>(OffsetResetStrategy.EARLIEST); + MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST); + + TopologyBuilder builder = new TopologyBuilder(); + builder.addSource("source1", "topic1"); + builder.addSource("source2", "topic2"); + builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2"); + + final Set<TaskId> prevTasks = Utils.mkSet( + new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1)); + final Set<TaskId> cachedTasks = Utils.mkSet( + new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1), + new TaskId(0, 2), new TaskId(1, 2), new TaskId(2, 2)); + + UUID uuid = UUID.randomUUID(); + StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime()) { + @Override + public Set<TaskId> prevTasks() { + return prevTasks; + } + @Override + public Set<TaskId> cachedTasks() { + return cachedTasks; + } + }; + + KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor(); + partitionAssignor.configure( + Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread) + ); + + PartitionAssignor.Subscription subscription = partitionAssignor.subscription(Utils.mkSet("topic1", "topic2")); + + assertEquals(Utils.mkList("topic1", "topic2"), subscription.topics()); + + Set<TaskId> standbyTasks = new HashSet<>(cachedTasks); + standbyTasks.removeAll(prevTasks); + + SubscriptionInfo info = new SubscriptionInfo(uuid, prevTasks, standbyTasks); + assertEquals(info.encode(), subscription.userData()); + } + + @Test + public void testAssign() throws Exception { + StreamingConfig config = new StreamingConfig(configProps()); + + MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer); + MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + MockConsumer<byte[], byte[]> mockRestoreConsumer = new 
MockConsumer<>(OffsetResetStrategy.LATEST); + + TopologyBuilder builder = new TopologyBuilder(); + builder.addSource("source1", "topic1"); + builder.addSource("source2", "topic2"); + builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2"); + + final Set<TaskId> prevTasks10 = Utils.mkSet(task0); + final Set<TaskId> prevTasks11 = Utils.mkSet(task1); + final Set<TaskId> prevTasks20 = Utils.mkSet(task2); + final Set<TaskId> standbyTasks10 = Utils.mkSet(task1); + final Set<TaskId> standbyTasks11 = Utils.mkSet(task2); + final Set<TaskId> standbyTasks20 = Utils.mkSet(task0); + + UUID uuid1 = UUID.randomUUID(); + UUID uuid2 = UUID.randomUUID(); + + StreamThread thread10 = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid1, new Metrics(), new SystemTime()); + + KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor(); + partitionAssignor.configure( + Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread10) + ); + + Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>(); + subscriptions.put("consumer10", + new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid1, prevTasks10, standbyTasks10).encode())); + subscriptions.put("consumer11", + new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid1, prevTasks11, standbyTasks11).encode())); + subscriptions.put("consumer20", + new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid2, prevTasks20, standbyTasks20).encode())); + + Map<String, PartitionAssignor.Assignment> assignments = partitionAssignor.assign(metadata, subscriptions); + + // check assigned partitions + + assertEquals(Utils.mkSet(Utils.mkSet(t1p0, t2p0), Utils.mkSet(t1p1, t2p1)), + Utils.mkSet(new HashSet<>(assignments.get("consumer10").partitions()), new 
HashSet<>(assignments.get("consumer11").partitions()))); + assertEquals(Utils.mkSet(t1p2, t2p2), new HashSet<>(assignments.get("consumer20").partitions())); + + // check assignment info + + List<TaskId> activeTasks = new ArrayList<>(); + for (TopicPartition partition : assignments.get("consumer10").partitions()) { + activeTasks.add(new TaskId(0, partition.partition())); + } + assertEquals(activeTasks, AssignmentInfo.decode(assignments.get("consumer10").userData()).activeTasks); + + activeTasks.clear(); + for (TopicPartition partition : assignments.get("consumer11").partitions()) { + activeTasks.add(new TaskId(0, partition.partition())); + } + assertEquals(activeTasks, AssignmentInfo.decode(assignments.get("consumer11").userData()).activeTasks); + + activeTasks.clear(); + for (TopicPartition partition : assignments.get("consumer20").partitions()) { + activeTasks.add(new TaskId(0, partition.partition())); + } + assertEquals(activeTasks, AssignmentInfo.decode(assignments.get("consumer20").userData()).activeTasks); + } + + @Test + public void testOnAssignment() throws Exception { + StreamingConfig config = new StreamingConfig(configProps()); + + MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer); + MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST); + + TopologyBuilder builder = new TopologyBuilder(); + builder.addSource("source1", "topic1"); + builder.addSource("source2", "topic2"); + builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2"); + + UUID uuid = UUID.randomUUID(); + + StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime()); + + KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor(); + partitionAssignor.configure( + 
Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread) + ); + + List<TaskId> activeTaskList = Utils.mkList(task0, task3); + Set<TaskId> standbyTasks = Utils.mkSet(task1, task2); + AssignmentInfo info = new AssignmentInfo(activeTaskList, standbyTasks); + PartitionAssignor.Assignment assignment = new PartitionAssignor.Assignment(Utils.mkList(t1p0, t2p3), info.encode()); + partitionAssignor.onAssignment(assignment); + + assertEquals(Utils.mkSet(task0), partitionAssignor.taskIds(t1p0)); + assertEquals(Utils.mkSet(task3), partitionAssignor.taskIds(t2p3)); + assertEquals(standbyTasks, partitionAssignor.standbyTasks()); + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java ---------------------------------------------------------------------- diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 909df13..54d0a18 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -38,13 +38,13 @@ import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.SystemTime; import org.apache.kafka.common.utils.Utils; import org.apache.kafka.streams.StreamingConfig; -import org.apache.kafka.streams.processor.PartitionGrouper; import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.TopologyBuilder; import org.apache.kafka.test.MockProcessorSupplier; import org.junit.Test; import java.io.File; +import java.nio.ByteBuffer; import java.nio.file.Files; import java.util.Arrays; import java.util.Collection; @@ -55,9 +55,12 @@ import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; 
+import java.util.UUID; public class StreamThreadTest { + private UUID uuid = UUID.randomUUID(); + private TopicPartition t1p1 = new TopicPartition("topic1", 1); private TopicPartition t1p2 = new TopicPartition("topic1", 2); private TopicPartition t2p1 = new TopicPartition("topic2", 1); @@ -79,7 +82,24 @@ public class StreamThreadTest { private Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos, Collections.<String>emptySet()); - PartitionAssignor.Subscription subscription = new PartitionAssignor.Subscription(Arrays.asList("topic1", "topic2", "topic3")); + private final PartitionAssignor.Subscription subscription = + new PartitionAssignor.Subscription(Arrays.asList("topic1", "topic2", "topic3"), subscriptionUserData()); + + private ByteBuffer subscriptionUserData() { + UUID uuid = UUID.randomUUID(); + ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + 4); + // version + buf.putInt(1); + // encode client clientUUID + buf.putLong(uuid.getMostSignificantBits()); + buf.putLong(uuid.getLeastSignificantBits()); + // previously running tasks + buf.putInt(0); + // cached tasks + buf.putInt(0); + buf.rewind(); + return buf; + } // task0 is unused private final TaskId task1 = new TaskId(0, 1); @@ -139,7 +159,7 @@ public class StreamThreadTest { builder.addSource("source3", "topic3"); builder.addProcessor("processor", new MockProcessorSupplier(), "source2", "source3"); - StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), new SystemTime()) { + StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime()) { @Override protected StreamTask createStreamTask(TaskId id, Collection<TopicPartition> partitionsForTask) { ProcessorTopology topology = builder.build(id.topicGroupId); @@ -259,7 +279,7 @@ public class StreamThreadTest { TopologyBuilder builder = new TopologyBuilder(); builder.addSource("source1", "topic1"); - 
StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), mockTime) { + StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), mockTime) { @Override public void maybeClean() { super.maybeClean(); @@ -381,7 +401,7 @@ public class StreamThreadTest { TopologyBuilder builder = new TopologyBuilder(); builder.addSource("source1", "topic1"); - StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), mockTime) { + StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), mockTime) { @Override public void maybeCommit() { super.maybeCommit(); @@ -448,12 +468,11 @@ public class StreamThreadTest { } private void initPartitionGrouper(StreamThread thread) { - PartitionGrouper partitionGrouper = thread.partitionGrouper(); KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor(); partitionAssignor.configure( - Collections.singletonMap(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper) + Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread) ); Map<String, PartitionAssignor.Assignment> assignments = http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java ---------------------------------------------------------------------- diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java new file mode 100644 index 0000000..58e0af9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java @@ -0,0 +1,45 @@ +/** + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +public class AssginmentInfoTest { + + @Test + public void testEncodeDecode() { + List<TaskId> activeTasks = + Arrays.asList(new TaskId(0, 0), new TaskId(0, 0), new TaskId(0, 1), new TaskId(1, 0)); + Set<TaskId> standbyTasks = + new HashSet<>(Arrays.asList(new TaskId(1, 1), new TaskId(2, 0))); + + AssignmentInfo info = new AssignmentInfo(activeTasks, standbyTasks); + AssignmentInfo decoded = AssignmentInfo.decode(info.encode()); + + assertEquals(info, decoded); + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java ---------------------------------------------------------------------- diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java new file mode 100644 index 0000000..acc9a9d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; + +public class SubscriptionInfoTest { + + @Test + public void testEncodeDecode() { + UUID clientUUID = UUID.randomUUID(); + Set<TaskId> activeTasks = + new HashSet<>(Arrays.asList(new TaskId(0, 0), new TaskId(0, 1), new TaskId(1, 0))); + Set<TaskId> standbyTasks = + new HashSet<>(Arrays.asList(new TaskId(1, 1), new TaskId(2, 0))); + + SubscriptionInfo info = new SubscriptionInfo(clientUUID, activeTasks, standbyTasks); + SubscriptionInfo decoded = SubscriptionInfo.decode(info.encode()); + + assertEquals(info, decoded); + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java ---------------------------------------------------------------------- diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java new file mode 100644 index 0000000..28364ab --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java @@ -0,0 +1,289 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals.assignment; + +import static org.apache.kafka.common.utils.Utils.mkList; +import static org.apache.kafka.common.utils.Utils.mkSet; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class TaskAssignorTest { + + @Test + public void testAssignWithoutStandby() { + HashMap<Integer, ClientState<Integer>> states = new HashMap<>(); + for (int i = 0; i < 6; i++) { + states.put(i, new ClientState<Integer>(1d)); + } + Set<Integer> tasks; + Map<Integer, ClientState<Integer>> assignments; + int numActiveTasks; + int numAssignedTasks; + + // # of clients and # of tasks are equal. 
+ tasks = mkSet(0, 1, 2, 3, 4, 5); + assignments = TaskAssignor.assign(states, tasks, 0); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertEquals(1, assignment.activeTasks.size()); + assertEquals(1, assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size(), numAssignedTasks); + + // # of clients < # of tasks + tasks = mkSet(0, 1, 2, 3, 4, 5, 6, 7); + assignments = TaskAssignor.assign(states, tasks, 0); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertTrue(1 <= assignment.activeTasks.size()); + assertTrue(2 >= assignment.activeTasks.size()); + assertTrue(1 <= assignment.assignedTasks.size()); + assertTrue(2 >= assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size(), numAssignedTasks); + + // # of clients > # of tasks + tasks = mkSet(0, 1, 2, 3); + assignments = TaskAssignor.assign(states, tasks, 0); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertTrue(0 <= assignment.activeTasks.size()); + assertTrue(1 >= assignment.activeTasks.size()); + assertTrue(0 <= assignment.assignedTasks.size()); + assertTrue(1 >= assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size(), numAssignedTasks); + } + + @Test + public void testAssignWithStandby() { + HashMap<Integer, ClientState<Integer>> states = new HashMap<>(); + for (int i = 0; i < 6; i++) { + states.put(i, new ClientState<Integer>(1d)); + } + 
Set<Integer> tasks; + Map<Integer, ClientState<Integer>> assignments; + int numActiveTasks; + int numAssignedTasks; + + // # of clients and # of tasks are equal. + tasks = mkSet(0, 1, 2, 3, 4, 5); + + // 1 standby replica. + numActiveTasks = 0; + numAssignedTasks = 0; + assignments = TaskAssignor.assign(states, tasks, 1); + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertEquals(1, assignment.activeTasks.size()); + assertEquals(2, assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size() * 2, numAssignedTasks); + + // # of clients < # of tasks + tasks = mkSet(0, 1, 2, 3, 4, 5, 6, 7); + + // 1 standby replica. + assignments = TaskAssignor.assign(states, tasks, 1); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertTrue(1 <= assignment.activeTasks.size()); + assertTrue(2 >= assignment.activeTasks.size()); + assertTrue(2 <= assignment.assignedTasks.size()); + assertTrue(3 >= assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size() * 2, numAssignedTasks); + + // # of clients > # of tasks + tasks = mkSet(0, 1, 2, 3); + + // 1 standby replica. 
+ assignments = TaskAssignor.assign(states, tasks, 1); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertTrue(0 <= assignment.activeTasks.size()); + assertTrue(1 >= assignment.activeTasks.size()); + assertTrue(1 <= assignment.assignedTasks.size()); + assertTrue(2 >= assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size() * 2, numAssignedTasks); + + // # of clients >> # of tasks + tasks = mkSet(0, 1); + + // 1 standby replica. + assignments = TaskAssignor.assign(states, tasks, 1); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertTrue(0 <= assignment.activeTasks.size()); + assertTrue(1 >= assignment.activeTasks.size()); + assertTrue(0 <= assignment.assignedTasks.size()); + assertTrue(1 >= assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size() * 2, numAssignedTasks); + + // 2 standby replicas. + assignments = TaskAssignor.assign(states, tasks, 2); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertTrue(0 <= assignment.activeTasks.size()); + assertTrue(1 >= assignment.activeTasks.size()); + assertTrue(1 == assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size() * 3, numAssignedTasks); + + // 3 standby replicas. 
+ assignments = TaskAssignor.assign(states, tasks, 3); + numActiveTasks = 0; + numAssignedTasks = 0; + for (ClientState<Integer> assignment : assignments.values()) { + numActiveTasks += assignment.activeTasks.size(); + numAssignedTasks += assignment.assignedTasks.size(); + assertTrue(0 <= assignment.activeTasks.size()); + assertTrue(1 >= assignment.activeTasks.size()); + assertTrue(1 <= assignment.assignedTasks.size()); + assertTrue(2 >= assignment.assignedTasks.size()); + } + assertEquals(tasks.size(), numActiveTasks); + assertEquals(tasks.size() * 4, numAssignedTasks); + } + + @Test + public void testStickiness() { + List<Integer> tasks; + Map<Integer, ClientState<Integer>> states; + Map<Integer, ClientState<Integer>> assignments; + int i; + + // # of clients and # of tasks are equal. + tasks = mkList(0, 1, 2, 3, 4, 5); + Collections.shuffle(tasks); + states = new HashMap<>(); + i = 0; + for (int task : tasks) { + ClientState<Integer> state = new ClientState<>(1d); + state.prevActiveTasks.add(task); + state.prevAssignedTasks.add(task); + states.put(i++, state); + } + assignments = TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5), 0); + for (int client : states.keySet()) { + Set<Integer> oldActive = states.get(client).prevActiveTasks; + Set<Integer> oldAssigned = states.get(client).prevAssignedTasks; + Set<Integer> newActive = assignments.get(client).activeTasks; + Set<Integer> newAssigned = assignments.get(client).assignedTasks; + + assertEquals(oldActive, newActive); + assertEquals(oldAssigned, newAssigned); + } + + // # of clients > # of tasks + tasks = mkList(0, 1, 2, 3, -1, -1); + Collections.shuffle(tasks); + states = new HashMap<>(); + i = 0; + for (int task : tasks) { + ClientState<Integer> state = new ClientState<>(1d); + if (task >= 0) { + state.prevActiveTasks.add(task); + state.prevAssignedTasks.add(task); + } + states.put(i++, state); + } + assignments = TaskAssignor.assign(states, mkSet(0, 1, 2, 3), 0); + for (int client : states.keySet()) { + 
Set<Integer> oldActive = states.get(client).prevActiveTasks; + Set<Integer> oldAssigned = states.get(client).prevAssignedTasks; + Set<Integer> newActive = assignments.get(client).activeTasks; + Set<Integer> newAssigned = assignments.get(client).assignedTasks; + + assertEquals(oldActive, newActive); + assertEquals(oldAssigned, newAssigned); + } + + // # of clients < # of tasks + List<Set<Integer>> taskSets = mkList(mkSet(0, 1), mkSet(2, 3), mkSet(4, 5), mkSet(6, 7), mkSet(8, 9), mkSet(10, 11)); + Collections.shuffle(taskSets); + states = new HashMap<>(); + i = 0; + for (Set<Integer> taskSet : taskSets) { + ClientState<Integer> state = new ClientState<>(1d); + state.prevActiveTasks.addAll(taskSet); + state.prevAssignedTasks.addAll(taskSet); + states.put(i++, state); + } + assignments = TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), 0); + for (int client : states.keySet()) { + Set<Integer> oldActive = states.get(client).prevActiveTasks; + Set<Integer> oldAssigned = states.get(client).prevAssignedTasks; + Set<Integer> newActive = assignments.get(client).activeTasks; + Set<Integer> newAssigned = assignments.get(client).assignedTasks; + + Set<Integer> intersection = new HashSet<>(); + + intersection.addAll(oldActive); + intersection.retainAll(newActive); + assertTrue(intersection.size() > 0); + + intersection.clear(); + intersection.addAll(oldAssigned); + intersection.retainAll(newAssigned); + assertTrue(intersection.size() > 0); + } + } + +}
