http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java new file mode 100644 index 0000000..eb33bce --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.messages; + +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.state.internal.InternalKvState; +import org.apache.flink.util.Preconditions; + +/** + * A {@link InternalKvState} instance request for a specific key and namespace. + */ +public final class KvStateRequest { + + /** ID for this request. */ + private final long requestId; + + /** ID of the requested KvState instance. 
*/ + private final KvStateID kvStateId; + + /** Serialized key and namespace to request from the KvState instance. */ + private final byte[] serializedKeyAndNamespace; + + /** + * Creates a KvState instance request. + * + * @param requestId ID for this request + * @param kvStateId ID of the requested KvState instance + * @param serializedKeyAndNamespace Serialized key and namespace to request from the KvState + * instance + */ + public KvStateRequest(long requestId, KvStateID kvStateId, byte[] serializedKeyAndNamespace) { + this.requestId = requestId; + this.kvStateId = Preconditions.checkNotNull(kvStateId, "KvStateID"); + this.serializedKeyAndNamespace = Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace"); + } + + /** + * Returns the request ID. + * + * @return Request ID + */ + public long getRequestId() { + return requestId; + } + + /** + * Returns the ID of the requested KvState instance. + * + * @return ID of the requested KvState instance + */ + public KvStateID getKvStateId() { + return kvStateId; + } + + /** + * Returns the serialized key and namespace to request from the KvState + * instance. + * + * @return Serialized key and namespace to request from the KvState instance + */ + public byte[] getSerializedKeyAndNamespace() { + return serializedKeyAndNamespace; + } + + @Override + public String toString() { + return "KvStateRequest{" + + "requestId=" + requestId + + ", kvStateId=" + kvStateId + + ", serializedKeyAndNamespace.length=" + serializedKeyAndNamespace.length + + '}'; + } +}
http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestFailure.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestFailure.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestFailure.java new file mode 100644 index 0000000..4015d79 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestFailure.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.messages; + +/** + * A failure response to a {@link KvStateRequest}. + */ +public final class KvStateRequestFailure { + + /** ID of the request responding to. */ + private final long requestId; + + /** Failure cause. Not allowed to be a user type. */ + private final Throwable cause; + + /** + * Creates a failure response to a {@link KvStateRequest}. 
+ * + * @param requestId ID for the request responding to + * @param cause Failure cause (not allowed to be a user type) + */ + public KvStateRequestFailure(long requestId, Throwable cause) { + this.requestId = requestId; + this.cause = cause; + } + + /** + * Returns the request ID responding to. + * + * @return Request ID responding to + */ + public long getRequestId() { + return requestId; + } + + /** + * Returns the failure cause. + * + * @return Failure cause + */ + public Throwable getCause() { + return cause; + } + + @Override + public String toString() { + return "KvStateRequestFailure{" + + "requestId=" + requestId + + ", cause=" + cause + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestResult.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestResult.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestResult.java new file mode 100644 index 0000000..6bf2397 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequestResult.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.messages; + +import org.apache.flink.util.Preconditions; + +/** + * A successful response to a {@link KvStateRequest} containing the serialized + * result for the requested key and namespace. + */ +public final class KvStateRequestResult { + + /** ID of the request responding to. */ + private final long requestId; + + /** + * Serialized result for the requested key and namespace. If no result was + * available for the specified key and namespace, this is <code>null</code>. + */ + private final byte[] serializedResult; + + /** + * Creates a successful {@link KvStateRequestResult} response. + * + * @param requestId ID of the request responding to + * @param serializedResult Serialized result or <code>null</code> if none + */ + public KvStateRequestResult(long requestId, byte[] serializedResult) { + this.requestId = requestId; + this.serializedResult = Preconditions.checkNotNull(serializedResult, "Serialization result"); + } + + /** + * Returns the request ID responding to. + * + * @return Request ID responding to + */ + public long getRequestId() { + return requestId; + } + + /** + * Returns the serialized result or <code>null</code> if none available. + * + * @return Serialized result or <code>null</code> if none available. 
+ */ + public byte[] getSerializedResult() { + return serializedResult; + } + + @Override + public String toString() { + return "KvStateRequestResult{" + + "requestId=" + requestId + + ", serializedResult.length=" + serializedResult.length + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageSerializer.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageSerializer.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageSerializer.java new file mode 100644 index 0000000..32bca64 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageSerializer.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.queryablestate.network.messages; + +import org.apache.flink.queryablestate.messages.KvStateRequest; +import org.apache.flink.queryablestate.messages.KvStateRequestFailure; +import org.apache.flink.queryablestate.messages.KvStateRequestResult; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.util.AbstractID; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufInputStream; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufOutputStream; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutput; +import java.io.ObjectOutputStream; + +/** + * Serialization and deserialization of messages exchanged between + * {@link org.apache.flink.queryablestate.client.KvStateClient client} and + * {@link org.apache.flink.queryablestate.server.KvStateServerImpl server}. + * + * <p>The binary messages have the following format: + * + * <pre> + * <------ Frame -------------------------> + * +----------------------------------------+ + * | HEADER (8) | PAYLOAD (VAR) | + * +------------------+----------------------------------------+ + * | FRAME LENGTH (4) | VERSION (4) | TYPE (4) | CONTENT (VAR) | + * +------------------+----------------------------------------+ + * </pre> + * + * <p>The concrete content of a message depends on the {@link MessageType}. + */ +public final class MessageSerializer { + + /** The serialization version ID. */ + private static final int VERSION = 0x79a1b710; + + /** Byte length of the header. */ + private static final int HEADER_LENGTH = 2 * Integer.BYTES; + + /** Byte length of the request id. 
*/ + private static final int REQUEST_ID_SIZE = Long.BYTES; + + // ------------------------------------------------------------------------ + // Serialization + // ------------------------------------------------------------------------ + + /** + * Allocates a buffer and serializes the KvState request into it. + * + * @param alloc ByteBuf allocator for the buffer to + * serialize message into + * @param requestId ID for this request + * @param kvStateId ID of the requested KvState instance + * @param serializedKeyAndNamespace Serialized key and namespace to request + * from the KvState instance. + * @return Serialized KvState request message + */ + public static ByteBuf serializeKvStateRequest( + ByteBufAllocator alloc, + long requestId, + KvStateID kvStateId, + byte[] serializedKeyAndNamespace) { + + // Header + request ID + KvState ID + Serialized namespace + int frameLength = HEADER_LENGTH + REQUEST_ID_SIZE + AbstractID.SIZE + (Integer.BYTES + serializedKeyAndNamespace.length); + ByteBuf buf = alloc.ioBuffer(frameLength + 4); // +4 for frame length + + buf.writeInt(frameLength); + + writeHeader(buf, MessageType.REQUEST); + + buf.writeLong(requestId); + buf.writeLong(kvStateId.getLowerPart()); + buf.writeLong(kvStateId.getUpperPart()); + buf.writeInt(serializedKeyAndNamespace.length); + buf.writeBytes(serializedKeyAndNamespace); + + return buf; + } + + /** + * Allocates a buffer and serializes the KvState request result into it. 
+ * + * @param alloc ByteBuf allocator for the buffer to serialize message into + * @param requestId ID for this request + * @param serializedResult Serialized Result + * @return Serialized KvState request result message + */ + public static ByteBuf serializeKvStateRequestResult( + ByteBufAllocator alloc, + long requestId, + byte[] serializedResult) { + + Preconditions.checkNotNull(serializedResult, "Serialized result"); + + // Header + request ID + serialized result + int frameLength = HEADER_LENGTH + REQUEST_ID_SIZE + 4 + serializedResult.length; + + // TODO: 10/5/17 there was a bug all this time? + ByteBuf buf = alloc.ioBuffer(frameLength + 4); + + buf.writeInt(frameLength); + writeHeader(buf, MessageType.REQUEST_RESULT); + buf.writeLong(requestId); + + buf.writeInt(serializedResult.length); + buf.writeBytes(serializedResult); + + return buf; + } + + /** + * Serializes the exception containing the failure message sent to the + * {@link org.apache.flink.queryablestate.client.KvStateClient} in case of + * protocol related errors. + * + * @param alloc The {@link ByteBufAllocator} used to allocate the buffer to serialize the message into. + * @param requestId The id of the request to which the message refers to. + * @param cause The exception thrown at the server. + * @return A {@link ByteBuf} containing the serialized message. 
+ */ + public static ByteBuf serializeKvStateRequestFailure( + final ByteBufAllocator alloc, + final long requestId, + final Throwable cause) throws IOException { + + final ByteBuf buf = alloc.ioBuffer(); + + // Frame length is set at the end + buf.writeInt(0); + writeHeader(buf, MessageType.REQUEST_FAILURE); + buf.writeLong(requestId); + + try (ByteBufOutputStream bbos = new ByteBufOutputStream(buf); + ObjectOutput out = new ObjectOutputStream(bbos)) { + out.writeObject(cause); + } + + // Set frame length + int frameLength = buf.readableBytes() - Integer.BYTES; + buf.setInt(0, frameLength); + return buf; + } + + /** + * Serializes the failure message sent to the + * {@link org.apache.flink.queryablestate.client.KvStateClient} in case of + * server related errors. + * + * @param alloc The {@link ByteBufAllocator} used to allocate the buffer to serialize the message into. + * @param cause The exception thrown at the server. + * @return The failure message. + */ + public static ByteBuf serializeServerFailure( + final ByteBufAllocator alloc, + final Throwable cause) throws IOException { + + final ByteBuf buf = alloc.ioBuffer(); + + // Frame length is set at end + buf.writeInt(0); + writeHeader(buf, MessageType.SERVER_FAILURE); + + try (ByteBufOutputStream bbos = new ByteBufOutputStream(buf); + ObjectOutput out = new ObjectOutputStream(bbos)) { + out.writeObject(cause); + } + + // Set frame length + int frameLength = buf.readableBytes() - Integer.BYTES; + buf.setInt(0, frameLength); + return buf; + } + + /** + * Helper for serializing the header. + * + * @param buf The {@link ByteBuf} to serialize the header into. + * @param messageType The {@link MessageType} of the message this header refers to. 
+ */ + private static void writeHeader(final ByteBuf buf, final MessageType messageType) { + buf.writeInt(VERSION); + buf.writeInt(messageType.ordinal()); + } + + // ------------------------------------------------------------------------ + // Deserialization + // ------------------------------------------------------------------------ + + /** + * De-serializes the header and returns the {@link MessageType}. + * <pre> + * <b>The buffer is expected to be at the header position.</b> + * </pre> + * @param buf The {@link ByteBuf} containing the serialized header. + * @return The message type. + * @throws IllegalStateException If unexpected message version or message type. + */ + public static MessageType deserializeHeader(final ByteBuf buf) { + + // checking the version + int version = buf.readInt(); + Preconditions.checkState(version == VERSION, + "Version Mismatch: Found " + version + ", Expected: " + VERSION + '.'); + + // fetching the message type + int msgType = buf.readInt(); + MessageType[] values = MessageType.values(); + Preconditions.checkState(msgType >= 0 && msgType <= values.length, + "Illegal message type with index " + msgType + '.'); + return values[msgType]; + } + + /** + * Deserializes the KvState request message. + * + * <p><strong>Important</strong>: the returned buffer is sliced from the + * incoming ByteBuf stream and retained. Therefore, it needs to be recycled + * by the consumer. + * + * @param buf Buffer to deserialize (expected to be positioned after header) + * @return Deserialized KvStateRequest + */ + public static KvStateRequest deserializeKvStateRequest(ByteBuf buf) { + long requestId = buf.readLong(); + KvStateID kvStateId = new KvStateID(buf.readLong(), buf.readLong()); + + // Serialized key and namespace + int length = buf.readInt(); + + if (length < 0) { + throw new IllegalArgumentException("Negative length for serialized key and namespace. 
" + + "This indicates a serialization error."); + } + + // Copy the buffer in order to be able to safely recycle the ByteBuf + byte[] serializedKeyAndNamespace = new byte[length]; + if (length > 0) { + buf.readBytes(serializedKeyAndNamespace); + } + + return new KvStateRequest(requestId, kvStateId, serializedKeyAndNamespace); + } + + /** + * Deserializes the KvState request result. + * + * @param buf Buffer to deserialize (expected to be positioned after header) + * @return Deserialized KvStateRequestResult + */ + public static KvStateRequestResult deserializeKvStateRequestResult(ByteBuf buf) { + long requestId = buf.readLong(); + + // Serialized KvState + int length = buf.readInt(); + + if (length < 0) { + throw new IllegalArgumentException("Negative length for serialized result. " + + "This indicates a serialization error."); + } + + byte[] serializedValue = new byte[length]; + + if (length > 0) { + buf.readBytes(serializedValue); + } + + return new KvStateRequestResult(requestId, serializedValue); + } + + /** + * De-serializes the {@link KvStateRequestFailure} sent to the + * {@link org.apache.flink.queryablestate.client.KvStateClient} in case of + * protocol related errors. + * <pre> + * <b>The buffer is expected to be at the correct position.</b> + * </pre> + * @param buf The {@link ByteBuf} containing the serialized failure message. + * @return The failure message. + */ + public static KvStateRequestFailure deserializeKvStateRequestFailure(final ByteBuf buf) throws IOException, ClassNotFoundException { + long requestId = buf.readLong(); + + Throwable cause; + try (ByteBufInputStream bis = new ByteBufInputStream(buf); + ObjectInputStream in = new ObjectInputStream(bis)) { + cause = (Throwable) in.readObject(); + } + return new KvStateRequestFailure(requestId, cause); + } + + /** + * De-serializes the failure message sent to the + * {@link org.apache.flink.queryablestate.client.KvStateClient} in case of + * server related errors. 
+ * <pre> + * <b>The buffer is expected to be at the correct position.</b> + * </pre> + * @param buf The {@link ByteBuf} containing the serialized failure message. + * @return The failure message. + */ + public static Throwable deserializeServerFailure(final ByteBuf buf) throws IOException, ClassNotFoundException { + try (ByteBufInputStream bis = new ByteBufInputStream(buf); + ObjectInputStream in = new ObjectInputStream(bis)) { + return (Throwable) in.readObject(); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageType.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageType.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageType.java new file mode 100644 index 0000000..4e4435d --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/network/messages/MessageType.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.network.messages; + +/** + * Expected message types during the communication between + * {@link org.apache.flink.queryablestate.client.KvStateClient state client} and + * {@link org.apache.flink.queryablestate.server.KvStateServerImpl state server}. + */ +public enum MessageType { + + /** The message is a request. */ + REQUEST, + + /** The message is a successful response. */ + REQUEST_RESULT, + + /** The message indicates a protocol-related failure. */ + REQUEST_FAILURE, + + /** The message indicates a server failure. */ + SERVER_FAILURE +} http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/ChunkedByteBuf.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/ChunkedByteBuf.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/ChunkedByteBuf.java new file mode 100644 index 0000000..f10969e --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/ChunkedByteBuf.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.server; + +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedInput; +import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler; + +/** + * A {@link ByteBuf} instance to be consumed in chunks by {@link ChunkedWriteHandler}, + * respecting the high and low watermarks. + * + * @see <a href="http://normanmaurer.me/presentations/2014-facebook-eng-netty/slides.html#10.0">Low/High Watermarks</a> + */ +public class ChunkedByteBuf implements ChunkedInput<ByteBuf> { + + /** The buffer to chunk. */ + private final ByteBuf buf; + + /** Size of chunks. */ + private final int chunkSize; + + /** Closed flag. */ + private boolean isClosed; + + /** End of input flag. */ + private boolean isEndOfInput; + + public ChunkedByteBuf(ByteBuf buf, int chunkSize) { + this.buf = Preconditions.checkNotNull(buf, "Buffer"); + Preconditions.checkArgument(chunkSize > 0, "Non-positive chunk size"); + this.chunkSize = chunkSize; + } + + @Override + public boolean isEndOfInput() throws Exception { + return isClosed || isEndOfInput; + } + + @Override + public void close() throws Exception { + if (!isClosed) { + // If we did not consume the whole buffer yet, we have to release + // it here. Otherwise, it's the responsibility of the consumer. 
+ if (!isEndOfInput) { + buf.release(); + } + + isClosed = true; + } + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + if (isClosed) { + return null; + } else if (buf.readableBytes() <= chunkSize) { + isEndOfInput = true; + + // Don't retain as the consumer is responsible to release it + return buf.slice(); + } else { + // Return a chunk sized slice of the buffer. The ref count is + // shared with the original buffer. That's why we need to retain + // a reference here. + return buf.readSlice(chunkSize).retain(); + } + } + + @Override + public String toString() { + return "ChunkedByteBuf{" + + "buf=" + buf + + ", chunkSize=" + chunkSize + + ", isClosed=" + isClosed + + ", isEndOfInput=" + isEndOfInput + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerHandler.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerHandler.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerHandler.java new file mode 100644 index 0000000..9a31fca --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerHandler.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.server; + +import org.apache.flink.queryablestate.UnknownKeyOrNamespace; +import org.apache.flink.queryablestate.UnknownKvStateID; +import org.apache.flink.queryablestate.messages.KvStateRequest; +import org.apache.flink.queryablestate.network.messages.MessageSerializer; +import org.apache.flink.queryablestate.network.messages.MessageType; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.netty.KvStateRequestStats; +import org.apache.flink.runtime.state.internal.InternalKvState; +import org.apache.flink.util.ExceptionUtils; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * This handler dispatches asynchronous tasks, which query {@link InternalKvState} + * instances and write the result to the channel. 
+ * + * <p>The network threads receive the message, deserialize it and dispatch the + * query task. The actual query is handled in a separate thread as it might + * otherwise block the network threads (file I/O etc.). + */ [email protected] +public class KvStateServerHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(KvStateServerHandler.class); + + /** KvState registry holding references to the KvState instances. */ + private final KvStateRegistry registry; + + /** Thread pool for query execution. */ + private final ExecutorService queryExecutor; + + /** Exposed server statistics. */ + private final KvStateRequestStats stats; + + /** + * Create the handler. + * + * @param kvStateRegistry Registry to query. + * @param queryExecutor Thread pool for query execution. + * @param stats Exposed server statistics. + */ + public KvStateServerHandler( + KvStateRegistry kvStateRegistry, + ExecutorService queryExecutor, + KvStateRequestStats stats) { + + this.registry = Objects.requireNonNull(kvStateRegistry, "KvStateRegistry"); + this.queryExecutor = Objects.requireNonNull(queryExecutor, "Query thread pool"); + this.stats = Objects.requireNonNull(stats, "KvStateRequestStats"); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + stats.reportActiveConnection(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + stats.reportInactiveConnection(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + KvStateRequest request = null; + + try { + ByteBuf buf = (ByteBuf) msg; + MessageType msgType = MessageSerializer.deserializeHeader(buf); + + if (msgType == MessageType.REQUEST) { + // ------------------------------------------------------------ + // Request + // ------------------------------------------------------------ + request = MessageSerializer.deserializeKvStateRequest(buf); + + 
stats.reportRequest(); + + InternalKvState<?> kvState = registry.getKvState(request.getKvStateId()); + + if (kvState != null) { + // Execute actual query async, because it is possibly + // blocking (e.g. file I/O). + // + // A submission failure is not treated as fatal. + queryExecutor.submit(new AsyncKvStateQueryTask(ctx, request, kvState, stats)); + } else { + ByteBuf unknown = MessageSerializer.serializeKvStateRequestFailure( + ctx.alloc(), + request.getRequestId(), + new UnknownKvStateID(request.getKvStateId())); + + ctx.writeAndFlush(unknown); + + stats.reportFailedRequest(); + } + } else { + // ------------------------------------------------------------ + // Unexpected + // ------------------------------------------------------------ + ByteBuf failure = MessageSerializer.serializeServerFailure( + ctx.alloc(), + new IllegalArgumentException("Unexpected message type " + msgType + + ". KvStateServerHandler expects " + + MessageType.REQUEST + " messages.")); + + ctx.writeAndFlush(failure); + } + } catch (Throwable t) { + String stringifiedCause = ExceptionUtils.stringifyException(t); + + ByteBuf err; + if (request != null) { + String errMsg = "Failed to handle incoming request with ID " + + request.getRequestId() + ". Caused by: " + stringifiedCause; + err = MessageSerializer.serializeKvStateRequestFailure( + ctx.alloc(), + request.getRequestId(), + new RuntimeException(errMsg)); + + stats.reportFailedRequest(); + } else { + String errMsg = "Failed to handle incoming message. Caused by: " + stringifiedCause; + err = MessageSerializer.serializeServerFailure( + ctx.alloc(), + new RuntimeException(errMsg)); + } + + ctx.writeAndFlush(err); + } finally { + // IMPORTANT: We have to always recycle the incoming buffer. + // Otherwise we will leak memory out of Netty's buffer pool. + // + // If any operation ever holds on to the buffer, it is the + // responsibility of that operation to retain the buffer and + // release it later. 
+ ReferenceCountUtil.release(msg); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + String stringifiedCause = ExceptionUtils.stringifyException(cause); + String msg = "Exception in server pipeline. Caused by: " + stringifiedCause; + + ByteBuf err = MessageSerializer.serializeServerFailure( + ctx.alloc(), + new RuntimeException(msg)); + + ctx.writeAndFlush(err).addListener(ChannelFutureListener.CLOSE); + } + + /** + * Task to execute the actual query against the {@link InternalKvState} instance. + */ + private static class AsyncKvStateQueryTask implements Runnable { + + private final ChannelHandlerContext ctx; + + private final KvStateRequest request; + + private final InternalKvState<?> kvState; + + private final KvStateRequestStats stats; + + private final long creationNanos; + + public AsyncKvStateQueryTask( + ChannelHandlerContext ctx, + KvStateRequest request, + InternalKvState<?> kvState, + KvStateRequestStats stats) { + + this.ctx = Objects.requireNonNull(ctx, "Channel handler context"); + this.request = Objects.requireNonNull(request, "State query"); + this.kvState = Objects.requireNonNull(kvState, "KvState"); + this.stats = Objects.requireNonNull(stats, "State query stats"); + this.creationNanos = System.nanoTime(); + } + + @Override + public void run() { + boolean success = false; + + try { + if (!ctx.channel().isActive()) { + return; + } + + // Query the KvState instance + byte[] serializedKeyAndNamespace = request.getSerializedKeyAndNamespace(); + byte[] serializedResult = kvState.getSerializedValue(serializedKeyAndNamespace); + + if (serializedResult != null) { + // We found some data, success! 
+ ByteBuf buf = MessageSerializer.serializeKvStateRequestResult( + ctx.alloc(), + request.getRequestId(), + serializedResult); + + int highWatermark = ctx.channel().config().getWriteBufferHighWaterMark(); + + ChannelFuture write; + if (buf.readableBytes() <= highWatermark) { + write = ctx.writeAndFlush(buf); + } else { + write = ctx.writeAndFlush(new ChunkedByteBuf(buf, highWatermark)); + } + + write.addListener(new QueryResultWriteListener()); + + success = true; + } else { + // No data for the key/namespace. This is considered to be + // a failure. + ByteBuf unknownKey = MessageSerializer.serializeKvStateRequestFailure( + ctx.alloc(), + request.getRequestId(), + new UnknownKeyOrNamespace()); + + ctx.writeAndFlush(unknownKey); + } + } catch (Throwable t) { + try { + String stringifiedCause = ExceptionUtils.stringifyException(t); + String errMsg = "Failed to query state backend for query " + + request.getRequestId() + ". Caused by: " + stringifiedCause; + + ByteBuf err = MessageSerializer.serializeKvStateRequestFailure( + ctx.alloc(), request.getRequestId(), new RuntimeException(errMsg)); + + ctx.writeAndFlush(err); + } catch (IOException e) { + LOG.error("Failed to respond with the error after failed to query state backend", e); + } + } finally { + if (!success) { + stats.reportFailedRequest(); + } + } + } + + @Override + public String toString() { + return "AsyncKvStateQueryTask{" + + ", request=" + request + + ", creationNanos=" + creationNanos + + '}'; + } + + /** + * Callback after query result has been written. + * + * <p>Gathers stats and logs errors. 
+ */ + private class QueryResultWriteListener implements ChannelFutureListener { + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + long durationNanos = System.nanoTime() - creationNanos; + long durationMillis = TimeUnit.MILLISECONDS.convert(durationNanos, TimeUnit.NANOSECONDS); + + if (future.isSuccess()) { + stats.reportSuccessfulRequest(durationMillis); + } else { + if (LOG.isDebugEnabled()) { + LOG.debug("Query " + request + " failed after " + durationMillis + " ms", future.cause()); + } + + stats.reportFailedRequest(); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerImpl.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerImpl.java b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerImpl.java new file mode 100644 index 0000000..4bf7e24 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/main/java/org/apache/flink/queryablestate/server/KvStateServerImpl.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.server; + +import org.apache.flink.queryablestate.messages.KvStateRequest; +import org.apache.flink.runtime.io.network.netty.NettyBufferPool; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.KvStateServer; +import org.apache.flink.runtime.query.KvStateServerAddress; +import org.apache.flink.runtime.query.netty.KvStateRequestStats; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption; +import org.apache.flink.shaded.netty4.io.netty.channel.EventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import 
java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; + +/** + * Netty-based server answering {@link KvStateRequest} messages. + * + * <p>Requests are handled by asynchronous query tasks (see {@link KvStateServerHandler.AsyncKvStateQueryTask}) + * that are executed by a separate query Thread pool. This pool is shared among + * all TCP connections. + * + * <p>The incoming pipeline looks as follows: + * <pre> + * Socket.read() -> LengthFieldBasedFrameDecoder -> KvStateServerHandler + * </pre> + * + * <p>Received binary messages are expected to contain a frame length field. Netty's + * {@link LengthFieldBasedFrameDecoder} is used to fully receive the frame before + * giving it to our {@link KvStateServerHandler}. + * + * <p>Connections are established and closed by the client. The server only + * closes the connection on a fatal failure that cannot be recovered. A + * server-side connection close is considered a failure by the client. + */ +public class KvStateServerImpl implements KvStateServer { + + private static final Logger LOG = LoggerFactory.getLogger(KvStateServer.class); + + /** Server config: low water mark. */ + private static final int LOW_WATER_MARK = 8 * 1024; + + /** Server config: high water mark. */ + private static final int HIGH_WATER_MARK = 32 * 1024; + + /** Netty's ServerBootstrap. */ + private final ServerBootstrap bootstrap; + + /** Query executor thread pool. */ + private final ExecutorService queryExecutor; + + /** Address of this server. */ + private KvStateServerAddress serverAddress; + + /** + * Creates the {@link KvStateServer}. + * + * <p>The server needs to be started via {@link #start()} in order to bind + * to the configured bind address. + * + * @param bindAddress Address to bind to + * @param bindPort Port to bind to. Pick random port if 0. 
+ * @param numEventLoopThreads Number of event loop threads + * @param numQueryThreads Number of query threads + * @param kvStateRegistry KvStateRegistry to query for KvState instances + * @param stats Statistics tracker + */ + public KvStateServerImpl( + InetAddress bindAddress, + Integer bindPort, + Integer numEventLoopThreads, + Integer numQueryThreads, + KvStateRegistry kvStateRegistry, + KvStateRequestStats stats) { + + Preconditions.checkArgument(bindPort >= 0 && bindPort <= 65536, "Port " + bindPort + + " is out of valid port range (0-65536)."); + + Preconditions.checkArgument(numEventLoopThreads >= 1, "Non-positive number of event loop threads."); + Preconditions.checkArgument(numQueryThreads >= 1, "Non-positive number of query threads."); + + Preconditions.checkNotNull(kvStateRegistry, "KvStateRegistry"); + Preconditions.checkNotNull(stats, "KvStateRequestStats"); + + NettyBufferPool bufferPool = new NettyBufferPool(numEventLoopThreads); + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Flink KvStateServer EventLoop Thread %d") + .build(); + + NioEventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory); + + queryExecutor = createQueryExecutor(numQueryThreads); + + // Shared between all channels + KvStateServerHandler serverHandler = new KvStateServerHandler( + kvStateRegistry, + queryExecutor, + stats); + + bootstrap = new ServerBootstrap() + // Bind address and port + .localAddress(bindAddress, bindPort) + // NIO server channels + .group(nioGroup) + .channel(NioServerSocketChannel.class) + // Server channel Options + .option(ChannelOption.ALLOCATOR, bufferPool) + // Child channel options + .childOption(ChannelOption.ALLOCATOR, bufferPool) + .childOption(ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, HIGH_WATER_MARK) + .childOption(ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, LOW_WATER_MARK) + // See initializer for pipeline details + .childHandler(new 
KvStateServerChannelInitializer(serverHandler)); + } + + @Override + public void start() throws InterruptedException { + Channel channel = bootstrap.bind().sync().channel(); + + InetSocketAddress localAddress = (InetSocketAddress) channel.localAddress(); + serverAddress = new KvStateServerAddress(localAddress.getAddress(), localAddress.getPort()); + } + + @Override + public KvStateServerAddress getAddress() { + if (serverAddress == null) { + throw new IllegalStateException("KvStateServer not started yet."); + } + + return serverAddress; + } + + @Override + public void shutDown() { + if (bootstrap != null) { + EventLoopGroup group = bootstrap.group(); + if (group != null) { + group.shutdownGracefully(0, 10, TimeUnit.SECONDS); + } + } + + if (queryExecutor != null) { + queryExecutor.shutdown(); + } + + serverAddress = null; + } + + /** + * Creates a thread pool for the query execution. + * + * @param numQueryThreads Number of query threads. + * @return Thread pool for query execution + */ + private static ExecutorService createQueryExecutor(int numQueryThreads) { + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Flink KvStateServer Query Thread %d") + .build(); + + return Executors.newFixedThreadPool(numQueryThreads, threadFactory); + } + + /** + * Channel pipeline initializer. + * + * <p>The request handler is shared, whereas the other handlers are created + * per channel. + */ + private static final class KvStateServerChannelInitializer extends ChannelInitializer<SocketChannel> { + + /** The shared request handler. */ + private final KvStateServerHandler sharedRequestHandler; + + /** + * Creates the channel pipeline initializer with the shared request handler. + * + * @param sharedRequestHandler Shared request handler. 
+ */ + public KvStateServerChannelInitializer(KvStateServerHandler sharedRequestHandler) { + this.sharedRequestHandler = Preconditions.checkNotNull(sharedRequestHandler, "Request handler"); + } + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + .addLast(new ChunkedWriteHandler()) + .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast(sharedRequestHandler); + } + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/29a6e995/flink-queryable-state/flink-queryable-state-java/src/test/java/org/apache/flink/queryablestate/itcases/AbstractQueryableStateITCase.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-java/src/test/java/org/apache/flink/queryablestate/itcases/AbstractQueryableStateITCase.java b/flink-queryable-state/flink-queryable-state-java/src/test/java/org/apache/flink/queryablestate/itcases/AbstractQueryableStateITCase.java new file mode 100644 index 0000000..a7f65f3 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-java/src/test/java/org/apache/flink/queryablestate/itcases/AbstractQueryableStateITCase.java @@ -0,0 +1,1128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.itcases; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.functions.FoldFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.queryablestate.UnknownKeyOrNamespace; +import org.apache.flink.queryablestate.client.QueryableStateClient; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.messages.JobManagerMessages; +import org.apache.flink.runtime.messages.JobManagerMessages.CancellationSuccess; +import org.apache.flink.runtime.minicluster.FlinkMiniCluster; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceTypeInfo; +import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages; +import org.apache.flink.streaming.api.datastream.DataStream; +import 
org.apache.flink.streaming.api.datastream.QueryableStateStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.TestLogger; + +import akka.actor.ActorSystem; +import akka.dispatch.Futures; +import akka.dispatch.OnSuccess; +import akka.dispatch.Recover; +import akka.pattern.Patterns; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicLongArray; + +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; +import scala.reflect.ClassTag$; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Base class for queryable state integration tests with a configurable state backend. + */ +public abstract class AbstractQueryableStateITCase extends TestLogger { + + protected static final FiniteDuration TEST_TIMEOUT = new FiniteDuration(10000, TimeUnit.SECONDS); + private static final FiniteDuration QUERY_RETRY_DELAY = new FiniteDuration(100, TimeUnit.MILLISECONDS); + + protected static ActorSystem testActorSystem; + + /** + * State backend to use. + */ + protected AbstractStateBackend stateBackend; + + /** + * Shared between all the test. Make sure to have at least NUM_SLOTS + * available after your test finishes, e.g. cancel the job you submitted. 
+ */ + protected static FlinkMiniCluster cluster; + + protected static int maxParallelism; + + @Before + public void setUp() throws Exception { + // NOTE: do not use a shared instance for all tests as the tests may brake + this.stateBackend = createStateBackend(); + + Assert.assertNotNull(cluster); + + maxParallelism = cluster.configuration().getInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, 1) * + cluster.configuration().getInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 1); + } + + /** + * Creates a state backend instance which is used in the {@link #setUp()} method before each + * test case. + * + * @return a state backend instance for each unit test + */ + protected abstract AbstractStateBackend createStateBackend() throws Exception; + + /** + * Runs a simple topology producing random (key, 1) pairs at the sources (where + * number of keys is in fixed in range 0...numKeys). The records are keyed and + * a reducing queryable state instance is created, which sums up the records. + * + * <p>After submitting the job in detached mode, the QueryableStateCLient is used + * to query the counts of each key in rounds until all keys have non-zero counts. + */ + @Test + @SuppressWarnings("unchecked") + public void testQueryableState() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + final int numKeys = 256; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + + try { + // + // Test program + // + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setStateBackend(stateBackend); + env.setParallelism(maxParallelism); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream<Tuple2<Integer, Long>> source = env + .addSource(new TestKeyRangeSource(numKeys)); + + // Reducing state + ReducingStateDescriptor<Tuple2<Integer, Long>> reducingState = new ReducingStateDescriptor<>( + "any-name", + new SumReduce(), + source.getType()); + + final String queryName = "hakuna-matata"; + + final QueryableStateStream<Integer, Tuple2<Integer, Long>> queryableState = + source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() { + private static final long serialVersionUID = 7143749578983540352L; + + @Override + public Integer getKey(Tuple2<Integer, Long> value) throws Exception { + return value.f0; + } + }).asQueryableState(queryName, reducingState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + cluster.submitJobDetached(jobGraph); + + // + // Start querying + // + jobId = jobGraph.getJobID(); + + final AtomicLongArray counts = new AtomicLongArray(numKeys); + + boolean allNonZero = false; + while (!allNonZero && deadline.hasTimeLeft()) { + allNonZero = true; + + final List<Future<Tuple2<Integer, Long>>> futures = new ArrayList<>(numKeys); + + for (int i = 0; i < numKeys; i++) { + final int key = i; + + if (counts.get(key) > 0) { + // Skip this one + continue; + } else { + allNonZero = false; + } + + Future<Tuple2<Integer, Long>> result = getKvStateWithRetries( + client, + jobId, + queryName, + key, + BasicTypeInfo.INT_TYPE_INFO, + reducingState, + QUERY_RETRY_DELAY, + false); + + result.onSuccess(new OnSuccess<Tuple2<Integer, Long>>() { + @Override + public void onSuccess(Tuple2<Integer, Long> result) throws Throwable { + counts.set(key, result.f1); + assertEquals("Key mismatch", key, result.f0.intValue()); + } + }, testActorSystem.dispatcher()); + + futures.add(result); + } + + Future<Iterable<Tuple2<Integer, Long>>> futureSequence = Futures.sequence( + futures, + testActorSystem.dispatcher()); + + 
Await.ready(futureSequence, deadline.timeLeft()); + } + + assertTrue("Not all keys are non-zero", allNonZero); + + // All should be non-zero + for (int i = 0; i < numKeys; i++) { + long count = counts.get(i); + assertTrue("Count at position " + i + " is " + count, count > 0); + } + } finally { + // Free cluster resources + if (jobId != null) { + Future<CancellationSuccess> cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Tests that duplicate query registrations fail the job at the JobManager. + * + * <b>NOTE: </b> This test is only in the non-HA variant of the tests because + * in the HA mode we use the actual JM code which does not recognize the + * {@code NotifyWhenJobStatus} message. * + */ + @Test + public void testDuplicateRegistrationFailsJob() throws Exception { + final Deadline deadline = TEST_TIMEOUT.fromNow(); + final int numKeys = 256; + + JobID jobId = null; + + try { + // + // Test program + // + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setStateBackend(stateBackend); + env.setParallelism(maxParallelism); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream<Tuple2<Integer, Long>> source = env + .addSource(new TestKeyRangeSource(numKeys)); + + // Reducing state + ReducingStateDescriptor<Tuple2<Integer, Long>> reducingState = new ReducingStateDescriptor<>( + "any-name", + new SumReduce(), + source.getType()); + + final String queryName = "duplicate-me"; + + final QueryableStateStream<Integer, Tuple2<Integer, Long>> queryableState = + source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() { + private static final long serialVersionUID = -4126824763829132959L; + + @Override + public Integer getKey(Tuple2<Integer, Long> value) throws Exception { + return value.f0; + } + }).asQueryableState(queryName, reducingState); + + final QueryableStateStream<Integer, Tuple2<Integer, Long>> duplicate = + source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() { + private static final long serialVersionUID = -6265024000462809436L; + + @Override + public Integer getKey(Tuple2<Integer, Long> value) throws Exception { + return value.f0; + } + }).asQueryableState(queryName); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + Future<TestingJobManagerMessages.JobStatusIs> failedFuture = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new TestingJobManagerMessages.NotifyWhenJobStatus(jobId, JobStatus.FAILED), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.<TestingJobManagerMessages.JobStatusIs>apply(TestingJobManagerMessages.JobStatusIs.class)); + + cluster.submitJobDetached(jobGraph); + + TestingJobManagerMessages.JobStatusIs jobStatus = Await.result(failedFuture, deadline.timeLeft()); + assertEquals(JobStatus.FAILED, jobStatus.state()); + + // Get the job and check the cause + JobManagerMessages.JobFound jobFound = Await.result( + cluster.getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.RequestJob(jobId), deadline.timeLeft()) + 
.mapTo(ClassTag$.MODULE$.<JobManagerMessages.JobFound>apply(JobManagerMessages.JobFound.class)), + deadline.timeLeft()); + + String failureCause = jobFound.executionGraph().getFailureCause().getExceptionAsString(); + + assertTrue("Not instance of SuppressRestartsException", failureCause.startsWith("org.apache.flink.runtime.execution.SuppressRestartsException")); + int causedByIndex = failureCause.indexOf("Caused by: "); + String subFailureCause = failureCause.substring(causedByIndex + "Caused by: ".length()); + assertTrue("Not caused by IllegalStateException", subFailureCause.startsWith("java.lang.IllegalStateException")); + assertTrue("Exception does not contain registration name", subFailureCause.contains(queryName)); + } finally { + // Free cluster resources + if (jobId != null) { + Future<CancellationSuccess> cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + } + } + + /** + * Tests simple value state queryable state instance. Each source emits + * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then + * queried. The tests succeeds after each subtask index is queried with + * value numElements (the latest element updated the state). + */ + @Test + public void testValueState() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + + final int numElements = 1024; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + try { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setStateBackend(stateBackend); + env.setParallelism(maxParallelism); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream<Tuple2<Integer, Long>> source = env + .addSource(new TestAscendingValueSource(numElements)); + + // Value state + ValueStateDescriptor<Tuple2<Integer, Long>> valueState = new ValueStateDescriptor<>( + "any", + source.getType()); + + QueryableStateStream<Integer, Tuple2<Integer, Long>> queryableState = + source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() { + private static final long serialVersionUID = 7662520075515707428L; + + @Override + public Integer getKey(Tuple2<Integer, Long> value) throws Exception { + return value.f0; + } + }).asQueryableState("hakuna", valueState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + cluster.submitJobDetached(jobGraph); + + // Now query + long expected = numElements; + + executeQuery(deadline, client, jobId, "hakuna", valueState, expected); + } finally { + // Free cluster resources + if (jobId != null) { + Future<CancellationSuccess> cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Similar tests as {@link #testValueState()} but before submitting the + * job, we already issue one request which fails. 
	 */
	@Test
	public void testQueryNonStartedJobState() throws Exception {
		// Config
		final Deadline deadline = TEST_TIMEOUT.fromNow();

		final int numElements = 1024;

		final QueryableStateClient client = new QueryableStateClient(cluster.configuration());

		JobID jobId = null;
		try {
			StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
			env.setStateBackend(stateBackend);
			env.setParallelism(maxParallelism);
			// Very important, because cluster is shared between tests and we
			// don't explicitly check that all slots are available before
			// submitting.
			env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000));

			DataStream<Tuple2<Integer, Long>> source = env
					.addSource(new TestAscendingValueSource(numElements));

			// Value state
			ValueStateDescriptor<Tuple2<Integer, Long>> valueState = new ValueStateDescriptor<>(
					"any",
					source.getType(),
					null);

			QueryableStateStream<Integer, Tuple2<Integer, Long>> queryableState =
					source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() {
						private static final long serialVersionUID = 7480503339992214681L;

						@Override
						public Integer getKey(Tuple2<Integer, Long> value) throws Exception {
							return value.f0;
						}
					}).asQueryableState("hakuna", valueState);

			// Submit the job graph
			JobGraph jobGraph = env.getStreamGraph().getJobGraph();
			jobId = jobGraph.getJobID();

			// Now query
			long expected = numElements;

			// query once
			// NOTE: the job has not been submitted yet at this point, so this first
			// request targets a non-started job. The returned future is deliberately
			// ignored; the point is only that issuing it must not break the later,
			// successful queries below.
			client.getKvState(
					jobId,
					queryableState.getQueryableStateName(),
					0,
					VoidNamespace.INSTANCE,
					BasicTypeInfo.INT_TYPE_INFO,
					VoidNamespaceTypeInfo.INSTANCE,
					valueState);

			cluster.submitJobDetached(jobGraph);

			executeQuery(deadline, client, jobId, "hakuna", valueState, expected);
		} finally {
			// Free cluster resources
			if (jobId != null) {
				Future<CancellationSuccess> cancellation = cluster
						.getLeaderGateway(deadline.timeLeft())
						.ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft())
						.mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(CancellationSuccess.class));

				Await.ready(cancellation, deadline.timeLeft());
			}

			client.shutDown();
		}
	}

	/**
	 * Retry a query for state for keys between 0 and {@link #maxParallelism} until
	 * <tt>expected</tt> equals the value of the result tuple's second field.
	 *
	 * @param deadline overall time budget; the loop gives up when it expires
	 * @param client client used to issue the queries
	 * @param jobId job whose queryable state is queried
	 * @param queryableStateName name under which the state was registered
	 * @param stateDescriptor descriptor of the queried state (used for deserialization)
	 * @param expected value the second tuple field must eventually reach
	 */
	private void executeQuery(
			final Deadline deadline,
			final QueryableStateClient client,
			final JobID jobId,
			final String queryableStateName,
			final StateDescriptor<?, Tuple2<Integer, Long>> stateDescriptor,
			final long expected) throws Exception {

		for (int key = 0; key < maxParallelism; key++) {
			boolean success = false;
			while (deadline.hasTimeLeft() && !success) {
				Future<Tuple2<Integer, Long>> future = getKvStateWithRetries(client,
					jobId,
					queryableStateName,
					key,
					BasicTypeInfo.INT_TYPE_INFO,
					stateDescriptor,
					QUERY_RETRY_DELAY,
					false);

				Tuple2<Integer, Long> value = Await.result(future, deadline.timeLeft());

				assertEquals("Key mismatch", key, value.f0.intValue());
				if (expected == value.f1) {
					success = true;
				} else {
					// Retry
					Thread.sleep(50);
				}
			}

			assertTrue("Did not succeed query", success);
		}
	}

	/**
	 * Retry a query for state for keys between 0 and {@link #maxParallelism} until
	 * <tt>expected</tt> equals the value of the result tuple's second field.
	 */
	private void executeQuery(
			final Deadline deadline,
			final QueryableStateClient client,
			final JobID jobId,
			final String queryableStateName,
			final TypeSerializer<Tuple2<Integer, Long>> valueSerializer,
			final long expected) throws Exception {

		// Same retry loop as the StateDescriptor-based overload, but resolves the
		// value through a raw TypeSerializer (used by the asQueryableState shortcut).
		for (int key = 0; key < maxParallelism; key++) {
			boolean success = false;
			while (deadline.hasTimeLeft() && !success) {
				Future<Tuple2<Integer, Long>> future = getKvStateWithRetries(client,
					jobId,
					queryableStateName,
					key,
					BasicTypeInfo.INT_TYPE_INFO,
					valueSerializer,
					QUERY_RETRY_DELAY,
					false);

				Tuple2<Integer, Long> value = Await.result(future, deadline.timeLeft());

				assertEquals("Key mismatch", key, value.f0.intValue());
				if (expected == value.f1) {
					success = true;
				} else {
					// Retry
					Thread.sleep(50);
				}
			}

			assertTrue("Did not succeed query", success);
		}
	}

	/**
	 * Tests simple value state queryable state instance with a default value
	 * set. Each source emits (subtaskIndex, 0)..(subtaskIndex, numElements)
	 * tuples, the key is mapped to 1 but key 0 is queried which should throw
	 * a {@link UnknownKeyOrNamespace} exception.
	 *
	 * @throws UnknownKeyOrNamespace thrown due querying a non-existent key
	 */
	@Test(expected = UnknownKeyOrNamespace.class)
	public void testValueStateDefault() throws
		Exception, UnknownKeyOrNamespace {

		// Config
		final Deadline deadline = TEST_TIMEOUT.fromNow();

		final int numElements = 1024;

		final QueryableStateClient client = new QueryableStateClient(cluster.configuration());

		JobID jobId = null;
		try {
			StreamExecutionEnvironment env =
				StreamExecutionEnvironment.getExecutionEnvironment();
			env.setStateBackend(stateBackend);
			env.setParallelism(maxParallelism);
			// Very important, because cluster is shared between tests and we
			// don't explicitly check that all slots are available before
			// submitting.
			env.setRestartStrategy(RestartStrategies
				.fixedDelayRestart(Integer.MAX_VALUE, 1000));

			DataStream<Tuple2<Integer, Long>> source = env
				.addSource(new TestAscendingValueSource(numElements));

			// Value state
			ValueStateDescriptor<Tuple2<Integer, Long>> valueState =
				new ValueStateDescriptor<>(
					"any",
					source.getType(),
					Tuple2.of(0, 1337L));

			// only expose key "1"
			QueryableStateStream<Integer, Tuple2<Integer, Long>>
				queryableState =
				source.keyBy(
					new KeySelector<Tuple2<Integer, Long>, Integer>() {
						private static final long serialVersionUID = 4509274556892655887L;

						@Override
						public Integer getKey(
							Tuple2<Integer, Long> value) throws
							Exception {
							return 1;
						}
					}).asQueryableState("hakuna", valueState);

			// Submit the job graph
			JobGraph jobGraph = env.getStreamGraph().getJobGraph();
			jobId = jobGraph.getJobID();

			cluster.submitJobDetached(jobGraph);

			// Now query
			// Key 0 was never written (all records map to key 1), so this lookup is
			// expected to fail with UnknownKeyOrNamespace, satisfying @Test(expected=...).
			int key = 0;
			Future<Tuple2<Integer, Long>> future = getKvStateWithRetries(client,
				jobId,
				queryableState.getQueryableStateName(),
				key,
				BasicTypeInfo.INT_TYPE_INFO,
				valueState,
				QUERY_RETRY_DELAY,
				true);

			Await.result(future, deadline.timeLeft());
		} finally {
			// Free cluster resources
			if (jobId != null) {
				Future<CancellationSuccess> cancellation = cluster
					.getLeaderGateway(deadline.timeLeft())
					.ask(new JobManagerMessages.CancelJob(jobId),
						deadline.timeLeft())
					.mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(
						CancellationSuccess.class));

				Await.ready(cancellation, deadline.timeLeft());
			}

			client.shutDown();
		}
	}

	/**
	 * Tests simple value state queryable state instance. Each source emits
	 * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then
	 * queried. The tests succeeds after each subtask index is queried with
	 * value numElements (the latest element updated the state).
	 *
	 * <p>This is the same as the simple value state test, but uses the API shortcut.
	 */
	@Test
	public void testValueStateShortcut() throws Exception {
		// Config
		final Deadline deadline = TEST_TIMEOUT.fromNow();

		final int numElements = 1024;

		final QueryableStateClient client = new QueryableStateClient(cluster.configuration());

		JobID jobId = null;
		try {
			StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
			env.setStateBackend(stateBackend);
			env.setParallelism(maxParallelism);
			// Very important, because cluster is shared between tests and we
			// don't explicitly check that all slots are available before
			// submitting.
			env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000));

			DataStream<Tuple2<Integer, Long>> source = env
					.addSource(new TestAscendingValueSource(numElements));

			// Value state shortcut
			// No explicit state descriptor here; the query side therefore has to use
			// the stream's value serializer (see executeQuery(TypeSerializer) overload).
			QueryableStateStream<Integer, Tuple2<Integer, Long>> queryableState =
					source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() {
						private static final long serialVersionUID = 9168901838808830068L;

						@Override
						public Integer getKey(Tuple2<Integer, Long> value) throws Exception {
							return value.f0;
						}
					}).asQueryableState("matata");

			// Submit the job graph
			JobGraph jobGraph = env.getStreamGraph().getJobGraph();
			jobId = jobGraph.getJobID();

			cluster.submitJobDetached(jobGraph);

			// Now query
			long expected = numElements;

			executeQuery(deadline, client, jobId, "matata",
				queryableState.getValueSerializer(), expected);
		} finally {
			// Free cluster resources
			if (jobId != null) {
				Future<CancellationSuccess> cancellation = cluster
						.getLeaderGateway(deadline.timeLeft())
						.ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft())
						.mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(CancellationSuccess.class));

				Await.ready(cancellation, deadline.timeLeft());
			}

			client.shutDown();
		}
	}

	/**
	 * Tests simple folding state queryable state instance.
	 * Each source emits
	 * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then
	 * queried. The folding state sums these up and maps them to Strings. The
	 * test succeeds after each subtask index is queried with result n*(n+1)/2
	 * (as a String).
	 */
	@Test
	public void testFoldingState() throws Exception {
		// Config
		final Deadline deadline = TEST_TIMEOUT.fromNow();

		final int numElements = 1024;

		final QueryableStateClient client = new QueryableStateClient(cluster.configuration());

		JobID jobId = null;
		try {
			StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
			env.setStateBackend(stateBackend);
			env.setParallelism(maxParallelism);
			// Very important, because cluster is shared between tests and we
			// don't explicitly check that all slots are available before
			// submitting.
			env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000));

			DataStream<Tuple2<Integer, Long>> source = env
					.addSource(new TestAscendingValueSource(numElements));

			// Folding state
			FoldingStateDescriptor<Tuple2<Integer, Long>, String> foldingState =
					new FoldingStateDescriptor<>(
						"any",
						"0",
						new SumFold(),
						StringSerializer.INSTANCE);

			QueryableStateStream<Integer, String> queryableState =
					source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() {
						private static final long serialVersionUID = -842809958106747539L;

						@Override
						public Integer getKey(Tuple2<Integer, Long> value) throws Exception {
							return value.f0;
						}
					}).asQueryableState("pumba", foldingState);

			// Submit the job graph
			JobGraph jobGraph = env.getStreamGraph().getJobGraph();
			jobId = jobGraph.getJobID();

			cluster.submitJobDetached(jobGraph);

			// Now query
			// Gauss sum over 0..numElements; safe in int arithmetic for numElements = 1024.
			String expected = Integer.toString(numElements * (numElements + 1) / 2);

			for (int key = 0; key < maxParallelism; key++) {
				boolean success = false;
				while (deadline.hasTimeLeft() && !success) {
					Future<String> future = getKvStateWithRetries(client,
							jobId,
							queryableState.getQueryableStateName(),
							key,
							BasicTypeInfo.INT_TYPE_INFO,
							foldingState,
							QUERY_RETRY_DELAY,
							false);

					String value = Await.result(future, deadline.timeLeft());
					if (expected.equals(value)) {
						success = true;
					} else {
						// Retry
						Thread.sleep(50);
					}
				}

				assertTrue("Did not succeed query", success);
			}
		} finally {
			// Free cluster resources
			if (jobId != null) {
				Future<CancellationSuccess> cancellation = cluster
						.getLeaderGateway(deadline.timeLeft())
						.ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft())
						.mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(CancellationSuccess.class));

				Await.ready(cancellation, deadline.timeLeft());
			}

			client.shutDown();
		}
	}

	/**
	 * Tests simple reducing state queryable state instance. Each source emits
	 * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then
	 * queried. The reducing state instance sums these up. The test succeeds
	 * after each subtask index is queried with result n*(n+1)/2.
	 */
	@Test
	public void testReducingState() throws Exception {
		// Config
		final Deadline deadline = TEST_TIMEOUT.fromNow();

		final int numElements = 1024;

		final QueryableStateClient client = new QueryableStateClient(cluster.configuration());

		JobID jobId = null;
		try {
			StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
			env.setStateBackend(stateBackend);
			env.setParallelism(maxParallelism);
			// Very important, because cluster is shared between tests and we
			// don't explicitly check that all slots are available before
			// submitting.
			env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000));

			DataStream<Tuple2<Integer, Long>> source = env
					.addSource(new TestAscendingValueSource(numElements));

			// Reducing state
			ReducingStateDescriptor<Tuple2<Integer, Long>> reducingState =
					new ReducingStateDescriptor<>(
						"any",
						new SumReduce(),
						source.getType());

			QueryableStateStream<Integer, Tuple2<Integer, Long>> queryableState =
					source.keyBy(new KeySelector<Tuple2<Integer, Long>, Integer>() {
						private static final long serialVersionUID = 8470749712274833552L;

						@Override
						public Integer getKey(Tuple2<Integer, Long> value) throws Exception {
							return value.f0;
						}
					}).asQueryableState("jungle", reducingState);

			// Submit the job graph
			JobGraph jobGraph = env.getStreamGraph().getJobGraph();
			jobId = jobGraph.getJobID();

			cluster.submitJobDetached(jobGraph);

			// Wait until job is running
			// NOTE(review): no explicit wait is performed here; the retrying query
			// below absorbs the start-up races instead.

			// Now query
			long expected = numElements * (numElements + 1) / 2;

			executeQuery(deadline, client, jobId, "jungle", reducingState, expected);
		} finally {
			// Free cluster resources
			if (jobId != null) {
				Future<CancellationSuccess> cancellation = cluster
						.getLeaderGateway(deadline.timeLeft())
						.ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft())
						.mapTo(ClassTag$.MODULE$.<CancellationSuccess>apply(CancellationSuccess.class));

				Await.ready(cancellation, deadline.timeLeft());
			}

			client.shutDown();
		}
	}

	/**
	 * Queries the state with the given serializer, retrying asynchronously
	 * (after {@code retryDelay}) on any failure except {@link AssertionError}
	 * and, if {@code failForUnknownKeyOrNamespace} is set,
	 * {@link UnknownKeyOrNamespace}. Retries recurse until the caller's
	 * surrounding deadline aborts the {@code Await}.
	 */
	private static <K, V> Future<V> getKvStateWithRetries(
			final QueryableStateClient client,
			final JobID jobId,
			final String queryName,
			final K key,
			final TypeInformation<K> keyTypeInfo,
			final TypeSerializer<V> valueTypeSerializer,
			final FiniteDuration retryDelay,
			final boolean failForUnknownKeyOrNamespace) {

		return client.getKvState(jobId, queryName, key, VoidNamespace.INSTANCE, keyTypeInfo, VoidNamespaceTypeInfo.INSTANCE, valueTypeSerializer)
				.recoverWith(new Recover<Future<V>>() {
					@Override
					public Future<V> recover(Throwable failure) throws Throwable {
						if (failure instanceof AssertionError) {
							return Futures.failed(failure);
						} else if (failForUnknownKeyOrNamespace &&
								(failure instanceof UnknownKeyOrNamespace)) {
							return Futures.failed(failure);
						} else {
							// At startup some failures are expected
							// due to races. Make sure that they don't
							// fail this test.
							return Patterns.after(
									retryDelay,
									testActorSystem.scheduler(),
									testActorSystem.dispatcher(),
									new Callable<Future<V>>() {
										@Override
										public Future<V> call() throws Exception {
											return getKvStateWithRetries(
													client,
													jobId,
													queryName,
													key,
													keyTypeInfo,
													valueTypeSerializer,
													retryDelay,
													failForUnknownKeyOrNamespace);
										}
									});
						}
					}
				}, testActorSystem.dispatcher());

	}

	/**
	 * Same retry logic as above, but resolving the value through a
	 * {@link StateDescriptor} instead of a raw serializer.
	 */
	private static <K, V> Future<V> getKvStateWithRetries(
			final QueryableStateClient client,
			final JobID jobId,
			final String queryName,
			final K key,
			final TypeInformation<K> keyTypeInfo,
			final StateDescriptor<?, V> stateDescriptor,
			final FiniteDuration retryDelay,
			final boolean failForUnknownKeyOrNamespace) {

		return client.getKvState(jobId, queryName, key, VoidNamespace.INSTANCE, keyTypeInfo, VoidNamespaceTypeInfo.INSTANCE, stateDescriptor)
				.recoverWith(new Recover<Future<V>>() {
					@Override
					public Future<V> recover(Throwable failure) throws Throwable {
						if (failure instanceof AssertionError) {
							return Futures.failed(failure);
						} else if (failForUnknownKeyOrNamespace &&
								(failure instanceof UnknownKeyOrNamespace)) {
							return Futures.failed(failure);
						} else {
							// At startup some failures are expected
							// due to races. Make sure that they don't
							// fail this test.
							return Patterns.after(
									retryDelay,
									testActorSystem.scheduler(),
									testActorSystem.dispatcher(),
									new Callable<Future<V>>() {
										@Override
										public Future<V> call() throws Exception {
											return getKvStateWithRetries(
													client,
													jobId,
													queryName,
													key,
													keyTypeInfo,
													stateDescriptor,
													retryDelay,
													failForUnknownKeyOrNamespace);
										}
									});
						}
					}
				}, testActorSystem.dispatcher());
	}

	/**
	 * Test source producing (key, 0)..(key, maxValue) with key being the sub
	 * task index.
	 *
	 * <p>After all tuples have been emitted, the source waits to be cancelled
	 * and does not immediately finish.
	 */
	private static class TestAscendingValueSource extends RichParallelSourceFunction<Tuple2<Integer, Long>> {

		private static final long serialVersionUID = 1459935229498173245L;

		// Last value emitted per key (inclusive upper bound of the sequence).
		private final long maxValue;
		// volatile: written by cancel() from a different thread than run().
		private volatile boolean isRunning = true;

		TestAscendingValueSource(long maxValue) {
			Preconditions.checkArgument(maxValue >= 0);
			this.maxValue = maxValue;
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
		}

		@Override
		public void run(SourceContext<Tuple2<Integer, Long>> ctx) throws Exception {
			// f0 => key
			int key = getRuntimeContext().getIndexOfThisSubtask();
			// The same Tuple2 instance is reused for every emission; only f1 changes.
			Tuple2<Integer, Long> record = new Tuple2<>(key, 0L);

			long currentValue = 0;
			while (isRunning && currentValue <= maxValue) {
				synchronized (ctx.getCheckpointLock()) {
					record.f1 = currentValue;
					ctx.collect(record);
				}

				currentValue++;
			}

			// Keep the source (and thus the job) alive until cancel() notifies us,
			// so the queryable state remains available for the test's queries.
			while (isRunning) {
				synchronized (this) {
					this.wait();
				}
			}
		}

		@Override
		public void cancel() {
			isRunning = false;

			synchronized (this) {
				this.notifyAll();
			}
		}

	}

	/**
	 * Test source producing (key, 1) tuples with random key in key range (numKeys).
+ */ + protected static class TestKeyRangeSource extends RichParallelSourceFunction<Tuple2<Integer, Long>> + implements CheckpointListener { + private static final long serialVersionUID = -5744725196953582710L; + + private static final AtomicLong LATEST_CHECKPOINT_ID = new AtomicLong(); + private final int numKeys; + private final ThreadLocalRandom random = ThreadLocalRandom.current(); + private volatile boolean isRunning = true; + + TestKeyRangeSource(int numKeys) { + this.numKeys = numKeys; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + LATEST_CHECKPOINT_ID.set(0); + } + } + + @Override + public void run(SourceContext<Tuple2<Integer, Long>> ctx) throws Exception { + // f0 => key + Tuple2<Integer, Long> record = new Tuple2<>(0, 1L); + + while (isRunning) { + synchronized (ctx.getCheckpointLock()) { + record.f0 = random.nextInt(numKeys); + ctx.collect(record); + } + // mild slow down + Thread.sleep(1); + } + } + + @Override + public void cancel() { + isRunning = false; + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + LATEST_CHECKPOINT_ID.set(checkpointId); + } + } + } + + /** + * Test {@link FoldFunction} concatenating the already stored string with the long passed as argument. + */ + private static class SumFold implements FoldFunction<Tuple2<Integer, Long>, String> { + private static final long serialVersionUID = -6249227626701264599L; + + @Override + public String fold(String accumulator, Tuple2<Integer, Long> value) throws Exception { + long acc = Long.valueOf(accumulator); + acc += value.f1; + return Long.toString(acc); + } + } + + /** + * Test {@link ReduceFunction} summing up its two arguments. 
	 */
	protected static class SumReduce implements ReduceFunction<Tuple2<Integer, Long>> {
		private static final long serialVersionUID = -8651235077342052336L;

		/**
		 * Sums the second fields; accumulates in place by mutating and returning
		 * {@code value1} (object reuse — assumes the runtime permits mutating the
		 * left reduce argument, as Flink's reduce contract does).
		 */
		@Override
		public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2) throws Exception {
			value1.f1 += value2.f1;
			return value1;
		}
	}

}
