http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/DataOutputSerializer.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/DataOutputSerializer.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/DataOutputSerializer.java new file mode 100644 index 0000000..5811c91 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/DataOutputSerializer.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.flink.queryablestate.client.state.serialization;

import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemoryUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.EOFException;
import java.io.IOException;
import java.io.UTFDataFormatException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

/**
 * A simple and efficient serializer for the {@link java.io.DataOutput} interface.
 *
 * <p>All data is written into a single on-heap byte array that grows on demand. Multi-byte
 * values are written in big-endian byte order. Instances hold unsynchronized mutable state.
 *
 * <p><b>THIS WAS COPIED FROM RUNTIME SO THAT WE AVOID THE DEPENDENCY.</b>
 */
public class DataOutputSerializer implements DataOutputView {

	private static final Logger LOG = LoggerFactory.getLogger(DataOutputSerializer.class);

	/** Buffers larger than this many bytes (5 MiB) are released by {@link #pruneBuffer()}. */
	private static final int PRUNE_BUFFER_THRESHOLD = 5 * 1024 * 1024;

	// ------------------------------------------------------------------------

	/** The buffer allocated by the constructor; reinstated by {@link #pruneBuffer()}. */
	private final byte[] startBuffer;

	/** The buffer currently written to; replaced by a larger array when it runs full. */
	private byte[] buffer;

	/** Index of the next byte to write, which is also the number of bytes written so far. */
	private int position;

	/** {@link ByteBuffer} view over {@link #buffer}; re-created whenever the buffer is replaced. */
	private ByteBuffer wrapper;

	// ------------------------------------------------------------------------

	/**
	 * Creates a serializer whose initial buffer holds {@code startSize} bytes.
	 *
	 * @param startSize initial buffer capacity in bytes, at least 1
	 * @throws IllegalArgumentException if {@code startSize} is smaller than 1
	 */
	public DataOutputSerializer(int startSize) {
		if (startSize < 1) {
			throw new IllegalArgumentException("startSize must be at least 1, but was " + startSize);
		}

		this.startBuffer = new byte[startSize];
		this.buffer = this.startBuffer;
		this.wrapper = ByteBuffer.wrap(buffer);
	}

	/**
	 * Returns a {@link ByteBuffer} view of the bytes written so far (position 0, limit
	 * {@link #length()}). The view is backed by the internal array, and the same instance is
	 * returned on every call until the internal buffer is replaced.
	 */
	public ByteBuffer wrapAsByteBuffer() {
		this.wrapper.position(0);
		this.wrapper.limit(this.position);
		return this.wrapper;
	}

	/**
	 * Returns the internal buffer without copying it. Only the first {@link #length()} bytes
	 * hold written data.
	 */
	public byte[] getByteArray() {
		return buffer;
	}

	/** Returns a copy of exactly the bytes written so far. */
	public byte[] getCopyOfBuffer() {
		return Arrays.copyOf(buffer, position);
	}

	/** Resets the write position to zero; the current buffer is kept. */
	public void clear() {
		this.position = 0;
	}

	/** Returns the number of bytes written so far. */
	public int length() {
		return this.position;
	}

	/**
	 * Drops the internal buffer if it has grown beyond 5 MiB and falls back to the buffer that
	 * was allocated in the constructor.
	 *
	 * <p>The written bytes are discarded together with the large buffer, so the write position
	 * is reset to zero in that case. Keeping the old position could leave it pointing beyond
	 * the end of the (smaller) start buffer, which makes subsequent writes and
	 * {@link #wrapAsByteBuffer()} fail.
	 */
	public void pruneBuffer() {
		if (this.buffer.length > PRUNE_BUFFER_THRESHOLD) {
			if (LOG.isDebugEnabled()) {
				LOG.debug("Releasing serialization buffer of " + this.buffer.length + " bytes.");
			}

			this.buffer = this.startBuffer;
			this.wrapper = ByteBuffer.wrap(this.buffer);
			// the previous contents are gone, so the position must not point past them
			this.position = 0;
		}
	}

	@Override
	public String toString() {
		return String.format("[pos=%d cap=%d]", this.position, this.buffer.length);
	}

	// ----------------------------------------------------------------------------------------
	// Data Output
	// ----------------------------------------------------------------------------------------

	@Override
	public void write(int b) throws IOException {
		if (this.position >= this.buffer.length) {
			resize(1);
		}
		this.buffer[this.position++] = (byte) (b & 0xff);
	}

	@Override
	public void write(byte[] b) throws IOException {
		write(b, 0, b.length);
	}

	/**
	 * Writes {@code len} bytes of {@code b}, starting at offset {@code off}.
	 *
	 * @throws ArrayIndexOutOfBoundsException if {@code off} or {@code len} is negative, or
	 *     {@code off + len} exceeds {@code b.length}
	 */
	@Override
	public void write(byte[] b, int off, int len) throws IOException {
		// validate before resizing so that an invalid call never grows the buffer;
		// 'off > b.length - len' is the overflow-safe form of 'off + len > b.length'
		if (off < 0 || len < 0 || off > b.length - len) {
			throw new ArrayIndexOutOfBoundsException(
				"off=" + off + ", len=" + len + ", array length=" + b.length);
		}
		if (this.position > this.buffer.length - len) {
			resize(len);
		}
		System.arraycopy(b, off, this.buffer, this.position, len);
		this.position += len;
	}

	@Override
	public void writeBoolean(boolean v) throws IOException {
		write(v ? 1 : 0);
	}

	@Override
	public void writeByte(int v) throws IOException {
		write(v);
	}

	/**
	 * Writes the low-order eight bits of every character of {@code s}, one byte per character.
	 */
	@Override
	public void writeBytes(String s) throws IOException {
		final int sLen = s.length();
		if (this.position > this.buffer.length - sLen) {
			resize(sLen);
		}

		// capacity is ensured above; each character advances the position exactly once
		for (int i = 0; i < sLen; i++) {
			this.buffer[this.position++] = (byte) s.charAt(i);
		}
	}

	@Override
	public void writeChar(int v) throws IOException {
		if (this.position >= this.buffer.length - 1) {
			resize(2);
		}
		this.buffer[this.position++] = (byte) (v >> 8);
		this.buffer[this.position++] = (byte) v;
	}

	@Override
	public void writeChars(String s) throws IOException {
		final int sLen = s.length();
		// grow once up front instead of once per character
		if (this.position >= this.buffer.length - 2 * sLen) {
			resize(2 * sLen);
		}

		for (int i = 0; i < sLen; i++) {
			writeChar(s.charAt(i));
		}
	}

	@Override
	public void writeDouble(double v) throws IOException {
		writeLong(Double.doubleToLongBits(v));
	}

	@Override
	public void writeFloat(float v) throws IOException {
		writeInt(Float.floatToIntBits(v));
	}

	@SuppressWarnings("restriction")
	@Override
	public void writeInt(int v) throws IOException {
		if (this.position >= this.buffer.length - 3) {
			resize(4);
		}
		// Unsafe writes in native order; reverse on little-endian machines to get big-endian
		if (LITTLE_ENDIAN) {
			v = Integer.reverseBytes(v);
		}
		UNSAFE.putInt(this.buffer, BASE_OFFSET + this.position, v);
		this.position += 4;
	}

	@SuppressWarnings("restriction")
	@Override
	public void writeLong(long v) throws IOException {
		if (this.position >= this.buffer.length - 7) {
			resize(8);
		}
		// Unsafe writes in native order; reverse on little-endian machines to get big-endian
		if (LITTLE_ENDIAN) {
			v = Long.reverseBytes(v);
		}
		UNSAFE.putLong(this.buffer, BASE_OFFSET + this.position, v);
		this.position += 8;
	}

	@Override
	public void writeShort(int v) throws IOException {
		if (this.position >= this.buffer.length - 1) {
			resize(2);
		}
		this.buffer[this.position++] = (byte) ((v >>> 8) & 0xff);
		this.buffer[this.position++] = (byte) ((v >>> 0) & 0xff);
	}

	/**
	 * Writes {@code str} as a two-byte big-endian length followed by its encoded characters.
	 * Characters 0x0001-0x007F take one byte. The character 0 and characters up to 0x07FF take
	 * two bytes. All larger characters take three bytes.
	 *
	 * @throws UTFDataFormatException if the encoded form is longer than 65535 bytes
	 */
	@Override
	public void writeUTF(String str) throws IOException {
		int strlen = str.length();
		int utflen = 0;
		int c;

		/* use charAt instead of copying String to char array */
		for (int i = 0; i < strlen; i++) {
			c = str.charAt(i);
			if ((c >= 0x0001) && (c <= 0x007F)) {
				utflen++;
			} else if (c > 0x07FF) {
				utflen += 3;
			} else {
				utflen += 2;
			}
		}

		if (utflen > 65535) {
			throw new UTFDataFormatException("Encoded string is too long: " + utflen);
		}
		else if (this.position > this.buffer.length - utflen - 2) {
			resize(utflen + 2);
		}

		byte[] bytearr = this.buffer;
		int count = this.position;

		bytearr[count++] = (byte) ((utflen >>> 8) & 0xFF);
		bytearr[count++] = (byte) ((utflen >>> 0) & 0xFF);

		// fast path for the leading run of single-byte characters
		int i;
		for (i = 0; i < strlen; i++) {
			c = str.charAt(i);
			if (!((c >= 0x0001) && (c <= 0x007F))) {
				break;
			}
			bytearr[count++] = (byte) c;
		}

		for (; i < strlen; i++) {
			c = str.charAt(i);
			if ((c >= 0x0001) && (c <= 0x007F)) {
				bytearr[count++] = (byte) c;

			} else if (c > 0x07FF) {
				bytearr[count++] = (byte) (0xE0 | ((c >> 12) & 0x0F));
				bytearr[count++] = (byte) (0x80 | ((c >> 6) & 0x3F));
				bytearr[count++] = (byte) (0x80 | ((c >> 0) & 0x3F));
			} else {
				bytearr[count++] = (byte) (0xC0 | ((c >> 6) & 0x1F));
				bytearr[count++] = (byte) (0x80 | ((c >> 0) & 0x3F));
			}
		}

		this.position = count;
	}

	/**
	 * Replaces the buffer with a larger one of at least {@code buffer.length + minCapacityAdd}
	 * bytes, preferring twice the current size. If the doubled array cannot be allocated, the
	 * minimal size is tried instead. The bytes written so far are copied over.
	 *
	 * @param minCapacityAdd number of bytes the buffer must grow by at least
	 * @throws IOException if the required size overflows or cannot be allocated
	 */
	private void resize(int minCapacityAdd) throws IOException {
		int newLen = Math.max(this.buffer.length * 2, this.buffer.length + minCapacityAdd);
		byte[] nb;
		try {
			nb = new byte[newLen];
		}
		catch (NegativeArraySizeException e) {
			throw new IOException("Serialization failed because the record length would exceed 2GB (max addressable array size in Java).");
		}
		catch (OutOfMemoryError e) {
			// this was too large to allocate, try the smaller size (if possible)
			if (newLen > this.buffer.length + minCapacityAdd) {
				newLen = this.buffer.length + minCapacityAdd;
				try {
					nb = new byte[newLen];
				}
				catch (OutOfMemoryError ee) {
					// still not possible. give an informative exception message that reports the size
					throw new IOException("Failed to serialize element. Serialized size (> "
							+ newLen + " bytes) exceeds JVM heap space", ee);
				}
			} else {
				throw new IOException("Failed to serialize element. Serialized size (> "
						+ newLen + " bytes) exceeds JVM heap space", e);
			}
		}

		System.arraycopy(this.buffer, 0, nb, 0, this.position);
		this.buffer = nb;
		this.wrapper = ByteBuffer.wrap(this.buffer);
	}

	/**
	 * Advances the position by {@code numBytes} without writing. The buffer is not grown.
	 *
	 * @throws EOFException if fewer than {@code numBytes} bytes remain in the current buffer
	 */
	@Override
	public void skipBytesToWrite(int numBytes) throws IOException {
		if (buffer.length - this.position < numBytes){
			throw new EOFException("Could not skip " + numBytes + " bytes.");
		}

		this.position += numBytes;
	}

	/**
	 * Copies {@code numBytes} bytes from {@code source} into the buffer. The buffer is not grown.
	 *
	 * @throws EOFException if fewer than {@code numBytes} bytes remain in the current buffer
	 */
	@Override
	public void write(DataInputView source, int numBytes) throws IOException {
		if (buffer.length - this.position < numBytes){
			throw new EOFException("Could not write " + numBytes + " bytes. Buffer overflow.");
		}

		source.readFully(this.buffer, this.position, numBytes);
		this.position += numBytes;
	}

	// ------------------------------------------------------------------------
	//  Utilities
	// ------------------------------------------------------------------------

	@SuppressWarnings("restriction")
	private static final sun.misc.Unsafe UNSAFE = MemoryUtils.UNSAFE;

	@SuppressWarnings("restriction")
	private static final long BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);

	private static final boolean LITTLE_ENDIAN = (MemoryUtils.NATIVE_BYTE_ORDER == ByteOrder.LITTLE_ENDIAN);
}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/KvStateSerializer.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/KvStateSerializer.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/KvStateSerializer.java new file mode 100644 index 0000000..4c69483 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/state/serialization/KvStateSerializer.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.flink.queryablestate.client.state.serialization;

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Serialization and deserialization of the different state types and namespaces.
 */
public final class KvStateSerializer {

	/**
	 * Separator byte written between the serialized key and the serialized namespace. It is
	 * verified again when the two are deserialized.
	 */
	private static final byte MAGIC_NUMBER = 42;

	// ------------------------------------------------------------------------
	// Generic serialization utils
	// ------------------------------------------------------------------------

	/**
	 * Serializes the key and namespace into a byte array (not a {@link ByteBuffer}).
	 *
	 * <p>The serialized format matches the RocksDB state backend key format, i.e.
	 * the key and namespace don't have to be deserialized for RocksDB lookups.
	 * The layout is: serialized key, one separator byte ({@code 42}), serialized namespace.
	 *
	 * @param key Key to serialize
	 * @param keySerializer Serializer for the key
	 * @param namespace Namespace to serialize
	 * @param namespaceSerializer Serializer for the namespace
	 * @param <K> Key type
	 * @param <N> Namespace type
	 * @return Array holding exactly the serialized key, separator byte and namespace
	 * @throws IOException Serialization errors are forwarded
	 */
	public static <K, N> byte[] serializeKeyAndNamespace(
			K key,
			TypeSerializer<K> keySerializer,
			N namespace,
			TypeSerializer<N> namespaceSerializer) throws IOException {

		DataOutputSerializer dos = new DataOutputSerializer(32);

		keySerializer.serialize(key, dos);
		dos.writeByte(MAGIC_NUMBER);
		namespaceSerializer.serialize(namespace, dos);

		return dos.getCopyOfBuffer();
	}

	/**
	 * Deserializes the key and namespace into a {@link Tuple2}.
	 *
	 * <p>Expects the layout produced by
	 * {@link #serializeKeyAndNamespace(Object, TypeSerializer, Object, TypeSerializer)} and
	 * requires that all bytes are consumed.
	 *
	 * @param serializedKeyAndNamespace Serialized key and namespace
	 * @param keySerializer Serializer for the key
	 * @param namespaceSerializer Serializer for the namespace
	 * @param <K> Key type
	 * @param <N> Namespace
	 * @return Tuple2 holding deserialized key and namespace
	 * @throws IOException if the separator byte is wrong, bytes are left over, or the
	 *     serializers fail; the original failure is attached as the cause
	 */
	public static <K, N> Tuple2<K, N> deserializeKeyAndNamespace(
			byte[] serializedKeyAndNamespace,
			TypeSerializer<K> keySerializer,
			TypeSerializer<N> namespaceSerializer) throws IOException {

		DataInputDeserializer dis = new DataInputDeserializer(
				serializedKeyAndNamespace,
				0,
				serializedKeyAndNamespace.length);

		try {
			K key = keySerializer.deserialize(dis);
			byte magicNumber = dis.readByte();
			if (magicNumber != MAGIC_NUMBER) {
				throw new IOException("Unexpected magic number " + magicNumber + ".");
			}
			N namespace = namespaceSerializer.deserialize(dis);

			if (dis.available() > 0) {
				throw new IOException("Unconsumed bytes in the serialized key and namespace.");
			}

			return new Tuple2<>(key, namespace);
		} catch (IOException e) {
			throw new IOException("Unable to deserialize key " +
				"and namespace. This indicates a mismatch in the key/namespace " +
				"serializers used by the KvState instance and this access.", e);
		}
	}

	/**
	 * Serializes the value with the given serializer.
	 *
	 * @param value Value of type T to serialize
	 * @param serializer Serializer for T
	 * @param <T> Type of the value
	 * @return Serialized value or <code>null</code> if value <code>null</code>
	 * @throws IOException On failure during serialization
	 */
	public static <T> byte[] serializeValue(T value, TypeSerializer<T> serializer) throws IOException {
		if (value == null) {
			return null;
		}

		DataOutputSerializer dos = new DataOutputSerializer(32);
		serializer.serialize(value, dos);
		return dos.getCopyOfBuffer();
	}

	/**
	 * Deserializes the value with the given serializer.
	 *
	 * @param serializedValue Serialized value of type T
	 * @param serializer Serializer for T
	 * @param <T> Type of the value
	 * @return Deserialized value or <code>null</code> if the serialized value
	 * is <code>null</code>
	 * @throws IOException On failure during deserialization, or if bytes are left over after
	 *     deserializing one value
	 */
	public static <T> T deserializeValue(byte[] serializedValue, TypeSerializer<T> serializer) throws IOException {
		if (serializedValue == null) {
			return null;
		}

		final DataInputDeserializer deser = new DataInputDeserializer(
			serializedValue, 0, serializedValue.length);
		final T value = serializer.deserialize(deser);
		if (deser.available() > 0) {
			throw new IOException(
				"Unconsumed bytes in the deserialized value. " +
					"This indicates a mismatch in the value serializers " +
					"used by the KvState instance and this access.");
		}
		return value;
	}

	/**
	 * Deserializes all values with the given serializer.
	 *
	 * <p>Consecutive elements are expected to be separated by a single byte, which is skipped.
	 *
	 * @param serializedValue Serialized value of type {@code List<T>}
	 * @param serializer Serializer for T
	 * @param <T> Type of the value
	 * @return Deserialized list or <code>null</code> if the serialized value
	 * is <code>null</code>
	 * @throws IOException On failure during deserialization
	 */
	public static <T> List<T> deserializeList(byte[] serializedValue, TypeSerializer<T> serializer) throws IOException {
		if (serializedValue == null) {
			return null;
		}

		final DataInputDeserializer in = new DataInputDeserializer(
			serializedValue, 0, serializedValue.length);

		try {
			final List<T> result = new ArrayList<>();
			while (in.available() > 0) {
				result.add(serializer.deserialize(in));

				// The expected binary format has a single byte separator. We
				// want a consistent binary format in order to not need any
				// special casing during deserialization. A "cleaner" format
				// would skip this extra byte, but would require a memory copy
				// for RocksDB, which stores the data serialized in this way
				// for lists.
				if (in.available() > 0) {
					in.readByte();
				}
			}

			return result;
		} catch (IOException e) {
			throw new IOException(
					"Unable to deserialize value. " +
						"This indicates a mismatch in the value serializers " +
						"used by the KvState instance and this access.", e);
		}
	}

	/**
	 * Serializes all entries of the Iterable with the given serializers.
	 *
	 * <p>For every entry, the key is written first, then a boolean null-flag, and then the
	 * value if it is not {@code null}.
	 *
	 * @param entries Key-value pairs to serialize
	 * @param keySerializer Serializer for UK
	 * @param valueSerializer Serializer for UV
	 * @param <UK> Type of the keys
	 * @param <UV> Type of the values
	 * @return Serialized entries, an empty array if {@code entries} is empty, or
	 *     <code>null</code> if {@code entries} is <code>null</code>
	 * @throws IOException On failure during serialization
	 */
	public static <UK, UV> byte[] serializeMap(Iterable<Map.Entry<UK, UV>> entries, TypeSerializer<UK> keySerializer, TypeSerializer<UV> valueSerializer) throws IOException {
		if (entries == null) {
			return null;
		}

		DataOutputSerializer dos = new DataOutputSerializer(32);

		for (Map.Entry<UK, UV> entry : entries) {
			keySerializer.serialize(entry.getKey(), dos);

			if (entry.getValue() == null) {
				dos.writeBoolean(true);
			} else {
				dos.writeBoolean(false);
				valueSerializer.serialize(entry.getValue(), dos);
			}
		}

		return dos.getCopyOfBuffer();
	}

	/**
	 * Deserializes all kv pairs with the given serializers.
	 *
	 * <p>Expects the layout produced by {@link #serializeMap(Iterable, TypeSerializer, TypeSerializer)}.
	 *
	 * @param serializedValue Serialized value of type {@code Map<UK, UV>}
	 * @param keySerializer Serializer for UK
	 * @param valueSerializer Serializer for UV
	 * @param <UK> Type of the key
	 * @param <UV> Type of the value.
	 * @return Deserialized map or <code>null</code> if the serialized value
	 * is <code>null</code>
	 * @throws IOException On failure during deserialization; the original failure is attached
	 *     as the cause
	 */
	public static <UK, UV> Map<UK, UV> deserializeMap(byte[] serializedValue, TypeSerializer<UK> keySerializer, TypeSerializer<UV> valueSerializer) throws IOException {
		if (serializedValue == null) {
			return null;
		}

		DataInputDeserializer in = new DataInputDeserializer(serializedValue, 0, serializedValue.length);

		try {
			Map<UK, UV> result = new HashMap<>();
			while (in.available() > 0) {
				UK key = keySerializer.deserialize(in);

				boolean isNull = in.readBoolean();
				UV value = isNull ? null : valueSerializer.deserialize(in);

				result.put(key, value);
			}

			return result;
		} catch (IOException e) {
			// wrap like the other deserializers so callers get a consistent diagnosis
			throw new IOException(
					"Unable to deserialize map. " +
						"This indicates a mismatch in the key/value serializers " +
						"used by the KvState instance and this access.", e);
		}
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownJobManagerException.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownJobManagerException.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownJobManagerException.java new file mode 100644 index 0000000..19063c2 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownJobManagerException.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.flink.queryablestate.exceptions;

import org.apache.flink.annotation.Internal;

/**
 * Exception used to fail a Future if the Task Manager on which the
 * {@code Client Proxy} is running does not know the active Job Manager.
 */
@Internal
public class UnknownJobManagerException extends Exception {

	private static final long serialVersionUID = 9092442511708951209L;

	/**
	 * Creates the exception with a fixed message. The message explains that the JobManager
	 * has either not registered yet or has lost leadership.
	 */
	public UnknownJobManagerException() {
		super("Unknown JobManager. Either the JobManager has not registered yet or has lost leadership.");
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKeyOrNamespaceException.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKeyOrNamespaceException.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKeyOrNamespaceException.java new file mode 100644 index 0000000..08e3324 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKeyOrNamespaceException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.flink.queryablestate.exceptions;

import org.apache.flink.annotation.Internal;
import org.apache.flink.queryablestate.network.BadRequestException;

/**
 * Signals that the KvState holds no state for the requested key or namespace.
 */
@Internal
public class UnknownKeyOrNamespaceException extends BadRequestException {

	private static final long serialVersionUID = 1L;

	/** The message attached to every instance of this exception. */
	private static final String NO_STATE_MESSAGE = "No state for the specified key/namespace.";

	/**
	 * Creates the exception for the server that could not find the requested state.
	 *
	 * @param serverName the name of the server that threw the exception.
	 */
	public UnknownKeyOrNamespaceException(String serverName) {
		super(serverName, NO_STATE_MESSAGE);
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateIdException.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateIdException.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateIdException.java new file mode 100644 index 0000000..81ea177 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateIdException.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.flink.queryablestate.exceptions;

import org.apache.flink.annotation.Internal;
import org.apache.flink.queryablestate.KvStateID;
import org.apache.flink.queryablestate.network.BadRequestException;
import org.apache.flink.util.Preconditions;

/**
 * Thrown if the server handler cannot find a KvState registered under the given ID.
 */
@Internal
public class UnknownKvStateIdException extends BadRequestException {

	private static final long serialVersionUID = 1L;

	/**
	 * Creates the exception.
	 *
	 * @param serverName the name of the server that threw the exception.
	 * @param kvStateId the state id for which no state was found; must not be {@code null}
	 *     (checked via {@link Preconditions#checkNotNull(Object)}).
	 */
	public UnknownKvStateIdException(String serverName, KvStateID kvStateId) {
		super(serverName, "No registered state with ID " + Preconditions.checkNotNull(kvStateId) + '.');
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateKeyGroupLocationException.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateKeyGroupLocationException.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateKeyGroupLocationException.java new file mode 100644 index 0000000..d8d34f7 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/exceptions/UnknownKvStateKeyGroupLocationException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.flink.queryablestate.exceptions;

import org.apache.flink.annotation.Internal;
import org.apache.flink.queryablestate.network.BadRequestException;

/**
 * Signals that no location information is available for the requested key group.
 */
@Internal
public class UnknownKvStateKeyGroupLocationException extends BadRequestException {

	private static final long serialVersionUID = 1L;

	/** The message attached to every instance of this exception. */
	private static final String UNKNOWN_LOCATION_MESSAGE = "Unknown key-group location.";

	/**
	 * Creates the exception for the server that could not locate the key group.
	 *
	 * @param serverName the name of the server that threw the exception.
	 */
	public UnknownKvStateKeyGroupLocationException(String serverName) {
		super(serverName, UNKNOWN_LOCATION_MESSAGE);
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java new file mode 100644 index 0000000..8169d48 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateRequest.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.flink.queryablestate.messages;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.queryablestate.network.messages.MessageBody;
+import org.apache.flink.queryablestate.network.messages.MessageDeserializer;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+
+/**
+ * The request to be sent by the {@link org.apache.flink.queryablestate.client.QueryableStateClient
+ * Queryable State Client} to the Client Proxy requesting a given state.
+ *
+ * <p>Wire format, as written by {@link #serialize()}: the job id as two longs (lower part, then
+ * upper part), the length of the state name (int) followed by its UTF-8 bytes, the key hash code
+ * (int), and the length of the serialized key and namespace (int) followed by those bytes.
+ */
+@Internal
+public class KvStateRequest extends MessageBody {
+
+	private final JobID jobId;
+	private final String stateName;
+	private final int keyHashCode;
+	private final byte[] serializedKeyAndNamespace;
+
+	/**
+	 * Creates a request.
+	 *
+	 * @param jobId the id of the job whose state is queried, must not be null
+	 * @param stateName the name of the queried state, must not be null
+	 * @param keyHashCode the hash code of the queried key
+	 * @param serializedKeyAndNamespace the serialized key and namespace, must not be null (not copied)
+	 */
+	public KvStateRequest(
+			final JobID jobId,
+			final String stateName,
+			final int keyHashCode,
+			final byte[] serializedKeyAndNamespace) {
+
+		this.jobId = Preconditions.checkNotNull(jobId);
+		this.stateName = Preconditions.checkNotNull(stateName);
+		this.keyHashCode = keyHashCode;
+		this.serializedKeyAndNamespace = Preconditions.checkNotNull(serializedKeyAndNamespace);
+	}
+
+	public JobID getJobId() {
+		return jobId;
+	}
+
+	public String getStateName() {
+		return stateName;
+	}
+
+	public int getKeyHashCode() {
+		return keyHashCode;
+	}
+
+	/**
+	 * Returns the serialized key and namespace. This is the internal array, not a copy.
+	 */
+	public byte[] getSerializedKeyAndNamespace() {
+		return serializedKeyAndNamespace;
+	}
+
+	@Override
+	public byte[] serialize() {
+
+		// Use an explicit charset: the platform default may differ between client and server.
+		byte[] serializedStateName = stateName.getBytes(StandardCharsets.UTF_8);
+
+		// JobID + stateName + sizeOf(stateName) + hashCode + keyAndNamespace + sizeOf(keyAndNamespace)
+		final int size =
+				JobID.SIZE +
+				serializedStateName.length + Integer.BYTES +
+				Integer.BYTES +
+				serializedKeyAndNamespace.length + Integer.BYTES;
+
+		return ByteBuffer.allocate(size)
+				.putLong(jobId.getLowerPart())
+				.putLong(jobId.getUpperPart())
+				.putInt(serializedStateName.length)
+				.put(serializedStateName)
+				.putInt(keyHashCode)
+				.putInt(serializedKeyAndNamespace.length)
+				.put(serializedKeyAndNamespace)
+				.array();
+	}
+
+	@Override
+	public String toString() {
+		return "KvStateRequest{" +
+				"jobId=" + jobId +
+				", stateName='" + stateName + '\'' +
+				", keyHashCode=" + keyHashCode +
+				", serializedKeyAndNamespace=" + Arrays.toString(serializedKeyAndNamespace) +
+				'}';
+	}
+
+	/**
+	 * A {@link MessageDeserializer deserializer} for {@link KvStateRequest}.
+	 *
+	 * <p>Reads the format written by {@link KvStateRequest#serialize()}.
+	 */
+	public static class KvStateRequestDeserializer implements MessageDeserializer<KvStateRequest> {
+
+		@Override
+		public KvStateRequest deserializeMessage(ByteBuf buf) {
+			JobID jobId = new JobID(buf.readLong(), buf.readLong());
+
+			int statenameLength = buf.readInt();
+			Preconditions.checkArgument(statenameLength >= 0,
+					"Negative length for state name. " +
+							"This indicates a serialization error.");
+			// Validate against the remaining bytes before allocating, so that a corrupted
+			// length cannot trigger an arbitrarily large allocation.
+			Preconditions.checkArgument(statenameLength <= buf.readableBytes(),
+					"Length for state name exceeds the remaining message bytes. " +
+							"This indicates a serialization error.");
+
+			String stateName = "";
+			if (statenameLength > 0) {
+				byte[] name = new byte[statenameLength];
+				buf.readBytes(name);
+				// Must match the charset used in KvStateRequest#serialize().
+				stateName = new String(name, StandardCharsets.UTF_8);
+			}
+
+			int keyHashCode = buf.readInt();
+
+			int knamespaceLength = buf.readInt();
+			Preconditions.checkArgument(knamespaceLength >= 0,
+					"Negative length for key and namespace. " +
+							"This indicates a serialization error.");
+			Preconditions.checkArgument(knamespaceLength <= buf.readableBytes(),
+					"Length for key and namespace exceeds the remaining message bytes. " +
+							"This indicates a serialization error.");
+
+			byte[] serializedKeyAndNamespace = new byte[knamespaceLength];
+			if (knamespaceLength > 0) {
+				buf.readBytes(serializedKeyAndNamespace);
+			}
+			return new KvStateRequest(jobId, stateName, keyHashCode, serializedKeyAndNamespace);
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateResponse.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateResponse.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateResponse.java
new file mode 100644
index 0000000..6bf14a7
--- /dev/null
+++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/messages/KvStateResponse.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.queryablestate.messages;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.queryablestate.network.messages.MessageBody;
+import org.apache.flink.queryablestate.network.messages.MessageDeserializer;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import java.nio.ByteBuffer;
+
+/**
+ * The response containing the (serialized) state sent by the {@code State Server} to the {@code Client Proxy},
+ * and then forwarded by the proxy to the original
+ * {@link org.apache.flink.queryablestate.client.QueryableStateClient Queryable State Client}.
+ *
+ * <p>Wire format, as written by {@link #serialize()}: the content length (int) followed by the content bytes.
+ */
+@Internal
+public class KvStateResponse extends MessageBody {
+
+	private final byte[] content;
+
+	/**
+	 * Creates a response.
+	 *
+	 * @param content the serialized state, must not be null (the array is not copied)
+	 */
+	public KvStateResponse(final byte[] content) {
+		this.content = Preconditions.checkNotNull(content);
+	}
+
+	/**
+	 * Returns the serialized state. This is the internal array, not a copy.
+	 */
+	public byte[] getContent() {
+		return content;
+	}
+
+	@Override
+	public byte[] serialize() {
+		final int size = Integer.BYTES + content.length;
+		return ByteBuffer.allocate(size)
+				.putInt(content.length)
+				.put(content)
+				.array();
+	}
+
+	/**
+	 * A {@link MessageDeserializer deserializer} for {@link KvStateResponse}.
+	 *
+	 * <p>Reads the format written by {@link KvStateResponse#serialize()}.
+	 */
+	public static class KvStateResponseDeserializer implements MessageDeserializer<KvStateResponse> {
+
+		@Override
+		public KvStateResponse deserializeMessage(ByteBuf buf) {
+			int length = buf.readInt();
+			Preconditions.checkArgument(length >= 0,
+					"Negative length for state content. " +
+							"This indicates a serialization error.");
+			// Validate against the remaining bytes before allocating, so that a corrupted
+			// length cannot trigger an arbitrarily large allocation.
+			Preconditions.checkArgument(length <= buf.readableBytes(),
+					"Length for state content exceeds the remaining message bytes. " +
+							"This indicates a serialization error.");
+			byte[] content = new byte[length];
+			buf.readBytes(content);
+
+			return new KvStateResponse(content);
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerBase.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerBase.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerBase.java
new file mode 100644
index 0000000..487020a
--- /dev/null
+++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerBase.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.queryablestate.network;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.queryablestate.network.messages.MessageBody;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption;
+import org.apache.flink.shaded.netty4.io.netty.channel.EventLoopGroup;
+import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup;
+import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel;
+import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel;
+import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.net.BindException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * The base class for every server in the queryable state module.
+ * It is using pure netty to send and receive messages of type {@link MessageBody}.
+ *
+ * @param <REQ> the type of request the server expects to receive.
+ * @param <RESP> the type of response the server will send.
+ */
+@Internal
+public abstract class AbstractServerBase<REQ extends MessageBody, RESP extends MessageBody> {
+
+	private static final Logger LOG = LoggerFactory.getLogger(AbstractServerBase.class);
+
+	/** AbstractServerBase config: low water mark. */
+	private static final int LOW_WATER_MARK = 8 * 1024;
+
+	/** AbstractServerBase config: high water mark. */
+	private static final int HIGH_WATER_MARK = 32 * 1024;
+
+	/** The name of the server, useful for debugging. */
+	private final String serverName;
+
+	/** The {@link InetAddress address} to listen to. */
+	private final InetAddress bindAddress;
+
+	/** The ports on which to try to bind, without duplicates, in the order they were provided. */
+	private final Set<Integer> bindPortRange;
+
+	/** The number of threads to be allocated to the event loop. */
+	private final int numEventLoopThreads;
+
+	/** The number of threads to be used for query serving. */
+	private final int numQueryThreads;
+
+	/** Netty's ServerBootstrap. */
+	private ServerBootstrap bootstrap;
+
+	/** Query executor thread pool. */
+	private ExecutorService queryExecutor;
+
+	/** Address of this server. */
+	private InetSocketAddress serverAddress;
+
+	/** The handler used for the incoming messages. */
+	private AbstractServerHandler<REQ, RESP> handler;
+
+	/**
+	 * Creates the {@link AbstractServerBase}.
+	 *
+	 * <p>The server needs to be started via {@link #start()}.
+	 *
+	 * @param serverName the name of the server
+	 * @param bindAddress address to bind to
+	 * @param bindPortIterator the candidate ports to bind to, each in [0, 65535]
+	 * @param numEventLoopThreads number of event loop threads, at least 1
+	 * @param numQueryThreads number of threads used to serve queries, at least 1
+	 */
+	protected AbstractServerBase(
+			final String serverName,
+			final InetAddress bindAddress,
+			final Iterator<Integer> bindPortIterator,
+			final Integer numEventLoopThreads,
+			final Integer numQueryThreads) {
+
+		Preconditions.checkNotNull(bindPortIterator);
+		Preconditions.checkArgument(numEventLoopThreads >= 1, "Non-positive number of event loop threads.");
+		Preconditions.checkArgument(numQueryThreads >= 1, "Non-positive number of query threads.");
+
+		this.serverName = Preconditions.checkNotNull(serverName);
+		this.bindAddress = Preconditions.checkNotNull(bindAddress);
+		this.numEventLoopThreads = numEventLoopThreads;
+		this.numQueryThreads = numQueryThreads;
+
+		// A LinkedHashSet drops duplicate ports but, unlike a HashSet, keeps the order
+		// in which they were provided, so ports are tried in that order.
+		this.bindPortRange = new LinkedHashSet<>();
+		while (bindPortIterator.hasNext()) {
+			int port = bindPortIterator.next();
+			Preconditions.checkArgument(port >= 0 && port <= 65535,
+					"Invalid port configuration. Port must be between 0 and 65535, but was " + port + ".");
+			bindPortRange.add(port);
+		}
+	}
+
+	/**
+	 * Creates a thread pool for the query execution.
+	 * @return Thread pool for query execution
+	 */
+	private ExecutorService createQueryExecutor() {
+		ThreadFactory threadFactory = new ThreadFactoryBuilder()
+				.setDaemon(true)
+				.setNameFormat("Flink " + getServerName() + " Thread %d")
+				.build();
+		return Executors.newFixedThreadPool(numQueryThreads, threadFactory);
+	}
+
+	/**
+	 * Returns the thread-pool responsible for processing incoming requests.
+	 */
+	protected ExecutorService getQueryExecutor() {
+		return queryExecutor;
+	}
+
+	/**
+	 * Gets the name of the server. This is useful for debugging.
+	 * @return The name of the server.
+	 */
+	public String getServerName() {
+		return serverName;
+	}
+
+	/**
+	 * Returns the {@link AbstractServerHandler handler} to be used for
+	 * serving the incoming requests.
+	 *
+	 * <p>This is called on every bind attempt, after the query executor of that attempt
+	 * has been created, so implementations may use {@link #getQueryExecutor()}.
+	 */
+	public abstract AbstractServerHandler<REQ, RESP> initializeHandler();
+
+	/**
+	 * Returns the address of this server.
+	 *
+	 * @return AbstractServerBase address
+	 * @throws IllegalStateException If server has not been started yet
+	 */
+	public InetSocketAddress getServerAddress() {
+		Preconditions.checkState(serverAddress != null, "Server " + serverName + " has not been started.");
+		return serverAddress;
+	}
+
+	/**
+	 * Starts the server by binding to the configured bind address (blocking).
+	 *
+	 * <p>The configured ports are tried one after the other until one can be bound.
+	 *
+	 * @throws IllegalStateException If the server is already running.
+	 * @throws FlinkRuntimeException If none of the configured ports could be bound.
+	 * @throws Throwable If anything else goes wrong during the bind operation.
+	 */
+	public void start() throws Throwable {
+		Preconditions.checkState(serverAddress == null,
+				"Server " + serverName + " already running @ " + serverAddress + '.');
+
+		Iterator<Integer> portIterator = bindPortRange.iterator();
+		while (portIterator.hasNext() && !attemptToBind(portIterator.next())) {}
+
+		if (serverAddress != null) {
+			LOG.info("Started server {} @ {}.", serverName, serverAddress);
+		} else {
+			LOG.info("Unable to start server {}. All ports in provided range are occupied.", serverName);
+			throw new FlinkRuntimeException("Unable to start server " + serverName + ". All ports in provided range are occupied.");
+		}
+	}
+
+	/**
+	 * Tries to start the server at the provided port.
+	 *
+	 * <p>This, in conjunction with {@link #start()}, try to start the
+	 * server on a free port among the port range provided at the constructor.
+	 *
+	 * <p>Every attempt creates a fresh query executor, handler and event loop group. If the
+	 * attempt fails, for any reason, they are released again via {@link #shutdown()}.
+	 *
+	 * @param port the port to try to bind the server to.
+	 * @return {@code true} if the server is bound to the port, {@code false} if binding
+	 *         failed with a {@link BindException}.
+	 * @throws Throwable If anything other than a {@link BindException} goes wrong.
+	 */
+	private boolean attemptToBind(final int port) throws Throwable {
+		LOG.debug("Attempting to start server {} on port {}.", serverName, port);
+
+		try {
+			this.queryExecutor = createQueryExecutor();
+			this.handler = initializeHandler();
+
+			final NettyBufferPool bufferPool = new NettyBufferPool(numEventLoopThreads);
+
+			final ThreadFactory threadFactory = new ThreadFactoryBuilder()
+					.setDaemon(true)
+					.setNameFormat("Flink " + serverName + " EventLoop Thread %d")
+					.build();
+
+			final NioEventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory);
+
+			this.bootstrap = new ServerBootstrap()
+					.localAddress(bindAddress, port)
+					.group(nioGroup)
+					.channel(NioServerSocketChannel.class)
+					.option(ChannelOption.ALLOCATOR, bufferPool)
+					.childOption(ChannelOption.ALLOCATOR, bufferPool)
+					.childOption(ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, HIGH_WATER_MARK)
+					.childOption(ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, LOW_WATER_MARK)
+					.childHandler(new ServerChannelInitializer<>(handler));
+
+			final ChannelFuture future = bootstrap.bind().sync();
+			if (future.isSuccess()) {
+				final InetSocketAddress localAddress = (InetSocketAddress) future.channel().localAddress();
+				serverAddress = new InetSocketAddress(localAddress.getAddress(), localAddress.getPort());
+				return true;
+			}
+
+			// the following throw is to bypass Netty's "optimization magic"
+			// and catch the bind exception.
+			// the exception is thrown by the sync() call above.
+
+			throw future.cause();
+		} catch (BindException e) {
+			LOG.debug("Failed to start server {} on port {}: {}.", serverName, port, e.getMessage());
+			shutdown();
+			return false;
+		} catch (Throwable t) {
+			// Any other failure bubbles up, but the executor, handler and event loop
+			// group created for this attempt must not leak.
+			shutdown();
+			throw t;
+		}
+	}
+
+	/**
+	 * Shuts down the server and all related thread pools.
+	 */
+	public void shutdown() {
+		LOG.info("Shutting down server {} @ {}", serverName, serverAddress);
+
+		if (handler != null) {
+			handler.shutdown();
+			handler = null;
+		}
+
+		if (queryExecutor != null) {
+			queryExecutor.shutdown();
+		}
+
+		if (bootstrap != null) {
+			EventLoopGroup group = bootstrap.group();
+			if (group != null) {
+				group.shutdownGracefully(0L, 10L, TimeUnit.SECONDS);
+			}
+		}
+		serverAddress = null;
+	}
+
+	/**
+	 * Channel pipeline initializer.
+	 *
+	 * <p>The request handler is shared, whereas the other handlers are created
+	 * per channel.
+	 */
+	private static final class ServerChannelInitializer<REQ extends MessageBody, RESP extends MessageBody> extends ChannelInitializer<SocketChannel> {
+
+		/** The shared request handler. */
+		private final AbstractServerHandler<REQ, RESP> sharedRequestHandler;
+
+		/**
+		 * Creates the channel pipeline initializer with the shared request handler.
+		 *
+		 * @param sharedRequestHandler Shared request handler.
+		 */
+		ServerChannelInitializer(AbstractServerHandler<REQ, RESP> sharedRequestHandler) {
+			this.sharedRequestHandler = Preconditions.checkNotNull(sharedRequestHandler, "MessageBody handler");
+		}
+
+		@Override
+		protected void initChannel(SocketChannel channel) throws Exception {
+			channel.pipeline()
+					.addLast(new ChunkedWriteHandler())
+					// Frames start with a 4-byte length field (offset 0), which is
+					// stripped before the frame reaches the request handler.
+					.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+					.addLast(sharedRequestHandler);
+		}
+	}
+
+	@VisibleForTesting
+	public boolean isExecutorShutdown() {
+		return queryExecutor.isShutdown();
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerHandler.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerHandler.java
b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerHandler.java new file mode 100644 index 0000000..9e02291 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/AbstractServerHandler.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.queryablestate.network;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.queryablestate.network.messages.MessageBody;
+import org.apache.flink.queryablestate.network.messages.MessageSerializer;
+import org.apache.flink.queryablestate.network.messages.MessageType;
+import org.apache.flink.queryablestate.network.stats.KvStateRequestStats;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
+import org.apache.flink.shaded.netty4.io.netty.util.ReferenceCountUtil;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * The base class of every handler used by an {@link AbstractServerBase}.
+ *
+ * <p>The same handler instance is added to the pipeline of every channel of its server,
+ * hence it is annotated as {@link ChannelHandler.Sharable sharable}.
+ *
+ * @param <REQ> the type of request the server expects to receive.
+ * @param <RESP> the type of response the server will send.
+ */
+@Internal
+@ChannelHandler.Sharable
+public abstract class AbstractServerHandler<REQ extends MessageBody, RESP extends MessageBody> extends ChannelInboundHandlerAdapter {
+
+	private static final Logger LOG = LoggerFactory.getLogger(AbstractServerHandler.class);
+
+	/** The owning server of this handler. */
+	private final AbstractServerBase<REQ, RESP> server;
+
+	/** The serializer used to (de-)serialize messages. */
+	private final MessageSerializer<REQ, RESP> serializer;
+
+	/** Thread pool for query execution. */
+	protected final ExecutorService queryExecutor;
+
+	/** Exposed server statistics. */
+	private final KvStateRequestStats stats;
+
+	/**
+	 * Create the handler.
+	 *
+	 * @param server the owning server; its query executor is used to run the requests
+	 * @param serializer the serializer used to (de-)serialize messages
+	 * @param stats statistics collector
+	 */
+	public AbstractServerHandler(
+			final AbstractServerBase<REQ, RESP> server,
+			final MessageSerializer<REQ, RESP> serializer,
+			final KvStateRequestStats stats) {
+
+		this.server = Preconditions.checkNotNull(server);
+		this.serializer = Preconditions.checkNotNull(serializer);
+		this.queryExecutor = server.getQueryExecutor();
+		this.stats = Preconditions.checkNotNull(stats);
+	}
+
+	protected String getServerName() {
+		return server.getServerName();
+	}
+
+	@Override
+	public void channelActive(ChannelHandlerContext ctx) throws Exception {
+		stats.reportActiveConnection();
+	}
+
+	@Override
+	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+		stats.reportInactiveConnection();
+	}
+
+	@Override
+	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+		REQ request = null;
+		long requestId = -1L;
+
+		try {
+			final ByteBuf buf = (ByteBuf) msg;
+			final MessageType msgType = MessageSerializer.deserializeHeader(buf);
+
+			requestId = MessageSerializer.getRequestId(buf);
+
+			if (msgType == MessageType.REQUEST) {
+
+				// ------------------------------------------------------------
+				// MessageBody
+				// ------------------------------------------------------------
+				request = serializer.deserializeRequest(buf);
+				stats.reportRequest();
+
+				// Execute actual query async, because it is possibly
+				// blocking (e.g. file I/O).
+				//
+				// A submission failure is not treated as fatal: it is reported
+				// to the client as a failed request by the catch block below.
+				queryExecutor.submit(new AsyncRequestTask<>(this, ctx, requestId, request, stats));
+
+			} else {
+				// ------------------------------------------------------------
+				// Unexpected
+				// ------------------------------------------------------------
+
+				final String errMsg = "Unexpected message type " + msgType + ". Expected " + MessageType.REQUEST + ".";
+				final ByteBuf failure = MessageSerializer.serializeServerFailure(ctx.alloc(), new IllegalArgumentException(errMsg));
+
+				LOG.debug(errMsg);
+				ctx.writeAndFlush(failure);
+			}
+		} catch (Throwable t) {
+			final String stringifiedCause = ExceptionUtils.stringifyException(t);
+
+			String errMsg;
+			ByteBuf err;
+			if (request != null) {
+				errMsg = "Failed request with ID " + requestId + ". Caused by: " + stringifiedCause;
+				err = MessageSerializer.serializeRequestFailure(ctx.alloc(), requestId, new RuntimeException(errMsg));
+				stats.reportFailedRequest();
+			} else {
+				errMsg = "Failed incoming message. Caused by: " + stringifiedCause;
+				err = MessageSerializer.serializeServerFailure(ctx.alloc(), new RuntimeException(errMsg));
+			}
+
+			LOG.debug(errMsg);
+			ctx.writeAndFlush(err);
+
+		} finally {
+			// IMPORTANT: We have to always recycle the incoming buffer.
+			// Otherwise we will leak memory out of Netty's buffer pool.
+			//
+			// If any operation ever holds on to the buffer, it is the
+			// responsibility of that operation to retain the buffer and
+			// release it later.
+			ReferenceCountUtil.release(msg);
+		}
+	}
+
+	@Override
+	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+		final String msg = "Exception in server pipeline. Caused by: " + ExceptionUtils.stringifyException(cause);
+		final ByteBuf err = MessageSerializer.serializeServerFailure(ctx.alloc(), new RuntimeException(msg));
+
+		LOG.debug(msg);
+		ctx.writeAndFlush(err).addListener(ChannelFutureListener.CLOSE);
+	}
+
+	/**
+	 * Handles an incoming request and returns a {@link CompletableFuture} containing the corresponding response.
+	 *
+	 * <p><b>NOTE:</b> This method is called by multiple threads.
+	 *
+	 * @param requestId the id of the received request to be handled.
+	 * @param request the request to be handled.
+	 * @return A future with the response to be forwarded to the client.
+	 */
+	public abstract CompletableFuture<RESP> handleRequest(final long requestId, final REQ request);
+
+	/**
+	 * Shuts down any handler specific resources, e.g. thread pools etc.
+	 */
+	public abstract void shutdown();
+
+	/**
+	 * Task to execute the actual query against the state instance.
+	 */
+	private static class AsyncRequestTask<REQ extends MessageBody, RESP extends MessageBody> implements Runnable {
+
+		private final AbstractServerHandler<REQ, RESP> handler;
+
+		private final ChannelHandlerContext ctx;
+
+		private final long requestId;
+
+		private final REQ request;
+
+		private final KvStateRequestStats stats;
+
+		private final long creationNanos;
+
+		AsyncRequestTask(
+				final AbstractServerHandler<REQ, RESP> handler,
+				final ChannelHandlerContext ctx,
+				final long requestId,
+				final REQ request,
+				final KvStateRequestStats stats) {
+
+			this.handler = Preconditions.checkNotNull(handler);
+			this.ctx = Preconditions.checkNotNull(ctx);
+			this.requestId = requestId;
+			this.request = Preconditions.checkNotNull(request);
+			this.stats = Preconditions.checkNotNull(stats);
+			this.creationNanos = System.nanoTime();
+		}
+
+		@Override
+		public void run() {
+
+			if (!ctx.channel().isActive()) {
+				return;
+			}
+
+			invokeHandler().whenComplete((resp, throwable) -> {
+				try {
+					if (throwable != null) {
+						// Unwrap a CompletionException, but never rethrow a missing (null) cause.
+						throw throwable instanceof CompletionException && throwable.getCause() != null
+								? throwable.getCause()
+								: throwable;
+					}
+
+					if (resp == null) {
+						throw new BadRequestException(handler.getServerName(), "NULL returned for request with ID " + requestId + ".");
+					}
+
+					final ByteBuf serialResp = MessageSerializer.serializeResponse(ctx.alloc(), requestId, resp);
+
+					// Responses larger than the channel's write buffer high water mark
+					// are written in chunks of at most that size.
+					int highWatermark = ctx.channel().config().getWriteBufferHighWaterMark();
+
+					ChannelFuture write;
+					if (serialResp.readableBytes() <= highWatermark) {
+						write = ctx.writeAndFlush(serialResp);
+					} else {
+						write = ctx.writeAndFlush(new ChunkedByteBuf(serialResp, highWatermark));
+					}
+					write.addListener(new RequestWriteListener());
+
+				} catch (BadRequestException e) {
+					try {
+						stats.reportFailedRequest();
+						final ByteBuf err = MessageSerializer.serializeRequestFailure(ctx.alloc(), requestId, e);
+						ctx.writeAndFlush(err);
+					} catch (IOException io) {
+						LOG.error("Failed to respond with the error after failed request", io);
+					}
+				} catch (Throwable t) {
+					try {
+						stats.reportFailedRequest();
+
+						final String errMsg = "Failed request " + requestId + ". Caused by: " + ExceptionUtils.stringifyException(t);
+						final ByteBuf err = MessageSerializer.serializeRequestFailure(ctx.alloc(), requestId, new RuntimeException(errMsg));
+						ctx.writeAndFlush(err);
+					} catch (IOException io) {
+						LOG.error("Failed to respond with the error after failed request", io);
+					}
+				}
+			});
+		}
+
+		/**
+		 * Calls {@code handler.handleRequest(requestId, request)} and turns an exception thrown
+		 * directly by it, or a {@code null} future, into an exceptionally completed future.
+		 *
+		 * <p>Without this, such a failure would only be captured by the {@code Future} returned
+		 * by the executor's {@code submit()}, and the client would never receive an answer.
+		 */
+		private CompletableFuture<RESP> invokeHandler() {
+			try {
+				final CompletableFuture<RESP> future = handler.handleRequest(requestId, request);
+				if (future != null) {
+					return future;
+				}
+				return failedFuture(new BadRequestException(handler.getServerName(), "NULL future returned for request with ID " + requestId + "."));
+			} catch (Throwable t) {
+				return failedFuture(t);
+			}
+		}
+
+		/** Creates a future that is already completed exceptionally with the given cause. */
+		private static <T> CompletableFuture<T> failedFuture(final Throwable cause) {
+			final CompletableFuture<T> future = new CompletableFuture<>();
+			future.completeExceptionally(cause);
+			return future;
+		}
+
+		@Override
+		public String toString() {
+			return "AsyncRequestTask{" +
+					"requestId=" + requestId +
+					", request=" + request +
+					'}';
+		}
+
+		/**
+		 * Callback after query result has been written.
+		 *
+		 * <p>Gathers stats and logs errors.
+		 */
+		private class RequestWriteListener implements ChannelFutureListener {
+
+			@Override
+			public void operationComplete(ChannelFuture future) throws Exception {
+				long durationNanos = System.nanoTime() - creationNanos;
+				long durationMillis = TimeUnit.MILLISECONDS.convert(durationNanos, TimeUnit.NANOSECONDS);
+
+				if (future.isSuccess()) {
+					LOG.debug("Request {} was successfully answered after {} ms.", request, durationMillis);
+					stats.reportSuccessfulRequest(durationMillis);
+				} else {
+					LOG.debug("Request {} failed after {} ms : ", request, durationMillis, future.cause());
+					stats.reportFailedRequest();
+				}
+			}
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/BadRequestException.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/BadRequestException.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/BadRequestException.java
new file mode 100644
index 0000000..3c0c484
--- /dev/null
+++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/BadRequestException.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.queryablestate.network;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Base class for exceptions thrown while querying Flink's managed state.
+ *
+ * <p>The exception message is prefixed with the name of the server that reports the problem.
+ */
+@Internal
+public class BadRequestException extends Exception {
+
+	private static final long serialVersionUID = 3458743952407632903L;
+
+	/**
+	 * Creates the exception.
+	 *
+	 * @param serverName name of the reporting server, must not be null
+	 * @param message description of the problem
+	 */
+	public BadRequestException(String serverName, String message) {
+		super(prefixWithServerName(serverName, message));
+	}
+
+	/**
+	 * Builds the exception message in the form {@code "<serverName> : <message>"}.
+	 * Throws a {@link NullPointerException} if the server name is null.
+	 */
+	private static String prefixWithServerName(String serverName, String message) {
+		final String nonNullServerName = Preconditions.checkNotNull(serverName);
+		return nonNullServerName + " : " + message;
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/ChunkedByteBuf.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/ChunkedByteBuf.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/ChunkedByteBuf.java
new file mode 100644
index 0000000..9c56025
--- /dev/null
+++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/ChunkedByteBuf.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.queryablestate.network;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedInput;
+import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler;
+
+/**
+ * A {@link ByteBuf} instance to be consumed in chunks by {@link ChunkedWriteHandler},
+ * respecting the high and low watermarks.
+ *
+ * <p>Every chunk except the last one is a retained slice of the wrapped buffer. The last
+ * chunk is handed out without an extra reference, so the consumer releasing it releases the
+ * wrapped buffer. If this input is closed before the last chunk was handed out, the wrapped
+ * buffer is released on {@link #close()}.
+ *
+ * @see <a href="http://normanmaurer.me/presentations/2014-facebook-eng-netty/slides.html#10.0">Low/High Watermarks</a>
+ */
+@Internal
+public class ChunkedByteBuf implements ChunkedInput<ByteBuf> {
+
+	/** The buffer to chunk. */
+	private final ByteBuf buf;
+
+	/** Maximum size of a chunk in bytes. */
+	private final int chunkSize;
+
+	/** Set once {@link #close()} has completed. */
+	private boolean isClosed;
+
+	/** Set once the last chunk has been handed out. */
+	private boolean isEndOfInput;
+
+	/**
+	 * Creates a chunked input over the given buffer.
+	 *
+	 * @param buf the buffer to chunk; must not be {@code null}.
+	 * @param chunkSize the maximum size of each chunk in bytes; must be positive.
+	 * @throws NullPointerException if {@code buf} is {@code null}.
+	 * @throws IllegalArgumentException if {@code chunkSize} is not positive.
+	 */
+	public ChunkedByteBuf(ByteBuf buf, int chunkSize) {
+		Preconditions.checkNotNull(buf, "Buffer");
+		Preconditions.checkArgument(chunkSize > 0, "Non-positive chunk size");
+
+		this.buf = buf;
+		this.chunkSize = chunkSize;
+	}
+
+	/** Reports end of input once this input is closed or the last chunk was handed out. */
+	@Override
+	public boolean isEndOfInput() throws Exception {
+		if (isClosed) {
+			return true;
+		}
+		return isEndOfInput;
+	}
+
+	/** Closes this input, releasing the wrapped buffer if it still owns it. */
+	@Override
+	public void close() throws Exception {
+		if (isClosed) {
+			return;
+		}
+
+		// Until the last chunk has been handed out, this instance owns the buffer and has to
+		// release it. Afterwards releasing it is the consumer's responsibility.
+		if (!isEndOfInput) {
+			buf.release();
+		}
+		isClosed = true;
+	}
+
+	/**
+	 * Returns the next chunk, or {@code null} once this input has been closed.
+	 */
+	@Override
+	public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {
+		if (isClosed) {
+			return null;
+		}
+
+		if (buf.readableBytes() > chunkSize) {
+			// A chunk sized slice shares the reference count of the wrapped buffer, hence it
+			// gets its own reference, which the consumer releases.
+			return buf.readSlice(chunkSize).retain();
+		}
+
+		// The remainder fits into one chunk. It is deliberately not retained: the consumer
+		// releasing it releases the wrapped buffer.
+		isEndOfInput = true;
+		return buf.slice();
+	}
+
+	@Override
+	public String toString() {
+		return new StringBuilder("ChunkedByteBuf{")
+				.append("buf=").append(buf)
+				.append(", chunkSize=").append(chunkSize)
+				.append(", isClosed=").append(isClosed)
+				.append(", isEndOfInput=").append(isEndOfInput)
+				.append('}')
+				.toString();
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java
new file mode 100644
index 0000000..13d34fb
--- /dev/null
+++ b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/network/Client.java
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.queryablestate.network; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.queryablestate.FutureUtils; +import org.apache.flink.queryablestate.network.messages.MessageBody; +import org.apache.flink.queryablestate.network.messages.MessageSerializer; +import org.apache.flink.queryablestate.network.stats.KvStateRequestStats; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption; +import org.apache.flink.shaded.netty4.io.netty.channel.EventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler; + +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; 
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * The base class for every client in the queryable state module.
+ * It is using pure netty to send and receive messages of type {@link MessageBody}.
+ *
+ * <p>Connections are tracked per server address. The first request to an address starts a
+ * {@link PendingConnection}, which queues requests until the channel is connected and then
+ * forwards them to an {@link EstablishedConnection}. Failed connections are unregistered, so
+ * a later request to the same address starts a new connection attempt.
+ *
+ * @param <REQ> the type of request the client will send.
+ * @param <RESP> the type of response the client expects to receive.
+ */
+@Internal
+public class Client<REQ extends MessageBody, RESP extends MessageBody> {
+
+	/** The name of the client. Used for logging and stack traces. */
+	private final String clientName;
+
+	/** Netty's Bootstrap. */
+	private final Bootstrap bootstrap;
+
+	/** The serializer to be used for (de-)serializing messages. */
+	private final MessageSerializer<REQ, RESP> messageSerializer;
+
+	/** Statistics tracker. */
+	private final KvStateRequestStats stats;
+
+	/** Established connections, keyed by server address. */
+	private final Map<InetSocketAddress, EstablishedConnection> establishedConnections = new ConcurrentHashMap<>();
+
+	/** Connections that are still being established, keyed by server address. */
+	private final Map<InetSocketAddress, PendingConnection> pendingConnections = new ConcurrentHashMap<>();
+
+	/** Atomic shut down flag. */
+	private final AtomicBoolean shutDown = new AtomicBoolean();
+
+	/**
+	 * Creates a client with the specified number of event loop threads.
+	 *
+	 * @param clientName the name of the client.
+	 * @param numEventLoopThreads number of event loop threads (minimum 1).
+	 * @param serializer the serializer used to (de-)serialize messages.
+	 * @param stats the statistics collector.
+	 */
+	public Client(
+			final String clientName,
+			final int numEventLoopThreads,
+			final MessageSerializer<REQ, RESP> serializer,
+			final KvStateRequestStats stats) {
+
+		Preconditions.checkArgument(numEventLoopThreads >= 1,
+				"Non-positive number of event loop threads.");
+
+		this.clientName = Preconditions.checkNotNull(clientName);
+		this.messageSerializer = Preconditions.checkNotNull(serializer);
+		this.stats = Preconditions.checkNotNull(stats);
+
+		final ThreadFactory threadFactory = new ThreadFactoryBuilder()
+				.setDaemon(true)
+				.setNameFormat("Flink " + clientName + " Event Loop Thread %d")
+				.build();
+
+		final EventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory);
+		final ByteBufAllocator bufferPool = new NettyBufferPool(numEventLoopThreads);
+
+		this.bootstrap = new Bootstrap()
+				.group(nioGroup)
+				.channel(NioSocketChannel.class)
+				.option(ChannelOption.ALLOCATOR, bufferPool)
+				.handler(new ChannelInitializer<SocketChannel>() {
+					@Override
+					protected void initChannel(SocketChannel channel) throws Exception {
+						channel.pipeline()
+								// Incoming frames carry a 4 byte length prefix at offset 0, which is stripped.
+								.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+								.addLast(new ChunkedWriteHandler());
+					}
+				});
+	}
+
+	/** Returns the name of this client. */
+	public String getClientName() {
+		return clientName;
+	}
+
+	/**
+	 * Sends a request to the given server.
+	 *
+	 * <p>Uses the established connection to {@code serverAddress} if there is one, joins a
+	 * connection attempt that is already in progress, or otherwise starts a new connection
+	 * attempt and queues the request until it completes.
+	 *
+	 * @param serverAddress the address of the server to send the request to.
+	 * @param request the request to send.
+	 * @return a future holding the response; failed with an {@link IllegalStateException}
+	 *         if the client has been shut down.
+	 */
+	public CompletableFuture<RESP> sendRequest(final InetSocketAddress serverAddress, final REQ request) {
+		if (shutDown.get()) {
+			return FutureUtils.getFailedFuture(new IllegalStateException("Shut down"));
+		}
+
+		final EstablishedConnection connection = establishedConnections.get(serverAddress);
+		if (connection != null) {
+			return connection.sendRequest(request);
+		}
+
+		final PendingConnection pendingConnection = pendingConnections.get(serverAddress);
+		if (pendingConnection != null) {
+			// There was a race, use the existing pending connection.
+			return pendingConnection.sendRequest(request);
+		}
+
+		// We try to connect to the server.
+		final PendingConnection pending = new PendingConnection(serverAddress, messageSerializer);
+		final PendingConnection previous = pendingConnections.putIfAbsent(serverAddress, pending);
+
+		if (previous == null) {
+			// OK, we are responsible to connect.
+			bootstrap.connect(serverAddress.getAddress(), serverAddress.getPort()).addListener(pending);
+			return pending.sendRequest(request);
+		} else {
+			// There was a race, use the existing pending connection.
+			return previous.sendRequest(request);
+		}
+	}
+
+	/**
+	 * Shuts down the client and closes all connections.
+	 *
+	 * <p>After a call to this method, all returned futures will be failed. Only the first call
+	 * has an effect.
+	 */
+	public void shutdown() {
+		if (shutDown.compareAndSet(false, true)) {
+			for (Map.Entry<InetSocketAddress, EstablishedConnection> conn : establishedConnections.entrySet()) {
+				if (establishedConnections.remove(conn.getKey(), conn.getValue())) {
+					conn.getValue().close();
+				}
+			}
+
+			for (Map.Entry<InetSocketAddress, PendingConnection> conn : pendingConnections.entrySet()) {
+				// Two-argument remove: only close the connection this entry actually refers to.
+				if (pendingConnections.remove(conn.getKey(), conn.getValue())) {
+					conn.getValue().close();
+				}
+			}
+
+			final EventLoopGroup group = bootstrap.group();
+			if (group != null) {
+				group.shutdownGracefully(0L, 10L, TimeUnit.SECONDS);
+			}
+		}
+	}
+
+	/**
+	 * A pending connection that is in the process of connecting.
+	 */
+	private class PendingConnection implements ChannelFutureListener {
+
+		/** Lock to guard the connect call, channel hand in, etc. */
+		private final Object connectLock = new Object();
+
+		/** Address of the server we are connecting to. */
+		private final InetSocketAddress serverAddress;
+
+		/** The serializer handed to the established connection. */
+		private final MessageSerializer<REQ, RESP> serializer;
+
+		/** Queue of requests while connecting. */
+		private final ArrayDeque<PendingRequest> queuedRequests = new ArrayDeque<>();
+
+		/** The established connection after the connect succeeds. */
+		private EstablishedConnection established;
+
+		/** Closed flag. */
+		private boolean closed;
+
+		/** Failure cause if something goes wrong. */
+		private Throwable failureCause;
+
+		/**
+		 * Creates a pending connection to the given server.
+		 *
+		 * @param serverAddress Address of the server to connect to.
+		 * @param serializer The serializer used once the connection is established.
+		 */
+		private PendingConnection(
+				final InetSocketAddress serverAddress,
+				final MessageSerializer<REQ, RESP> serializer) {
+			this.serverAddress = serverAddress;
+			this.serializer = serializer;
+		}
+
+		/**
+		 * Completion callback of the connect attempt: hands in the channel on success,
+		 * closes this pending connection with the connect failure otherwise.
+		 */
+		@Override
+		public void operationComplete(ChannelFuture future) throws Exception {
+			if (future.isSuccess()) {
+				handInChannel(future.channel());
+			} else {
+				close(future.cause());
+			}
+		}
+
+		/**
+		 * Returns a future holding the serialized request result.
+		 *
+		 * <p>If the channel has been established, forward the call to the
+		 * established channel, otherwise queue it for when the channel is
+		 * handed in.
+		 *
+		 * @param request the request to be sent.
+		 * @return Future holding the serialized result
+		 */
+		public CompletableFuture<RESP> sendRequest(REQ request) {
+			synchronized (connectLock) {
+				if (failureCause != null) {
+					return FutureUtils.getFailedFuture(failureCause);
+				} else if (closed) {
+					return FutureUtils.getFailedFuture(new ClosedChannelException());
+				} else {
+					if (established != null) {
+						return established.sendRequest(request);
+					} else {
+						// Queue this and handle when connected
+						final PendingRequest pending = new PendingRequest(request);
+						queuedRequests.add(pending);
+						return pending;
+					}
+				}
+			}
+		}
+
+		/**
+		 * Hands in a channel after a successful connection.
+		 *
+		 * @param channel Channel to hand in
+		 */
+		private void handInChannel(Channel channel) {
+			synchronized (connectLock) {
+				if (closed || failureCause != null) {
+					// Close the channel and we are done. Any queued requests
+					// are removed on the close/failure call and after that no
+					// new ones can be enqueued.
+					channel.close();
+				} else {
+					established = new EstablishedConnection(serverAddress, serializer, channel);
+
+					while (!queuedRequests.isEmpty()) {
+						final PendingRequest pending = queuedRequests.poll();
+
+						// Forward the raw outcome. Chaining thenAccept(..).exceptionally(..) would
+						// hand queued callers a CompletionException wrapper instead of the cause
+						// that callers of an established connection receive.
+						established.sendRequest(pending.request).whenComplete((response, throwable) -> {
+							if (throwable != null) {
+								pending.completeExceptionally(throwable);
+							} else {
+								pending.complete(response);
+							}
+						});
+					}
+
+					// Publish the channel for the general public
+					establishedConnections.put(serverAddress, established);
+					pendingConnections.remove(serverAddress, this);
+
+					// Check shut down for possible race with shut down. We
+					// don't want any lingering connections after shut down,
+					// which can happen if we don't check this here.
+					if (shutDown.get()) {
+						if (establishedConnections.remove(serverAddress, established)) {
+							established.close();
+						}
+					}
+				}
+			}
+		}
+
+		/**
+		 * Close the connecting channel with a ClosedChannelException.
+		 */
+		private void close() {
+			close(new ClosedChannelException());
+		}
+
+		/**
+		 * Close the connecting channel with an Exception or forward to the established channel,
+		 * and unregister this pending connection.
+		 *
+		 * @param cause the failure cause; if {@code null}, a {@link ClosedChannelException} is used.
+		 */
+		private void close(Throwable cause) {
+			// CompletableFuture#completeExceptionally rejects null.
+			final Throwable closeCause = cause != null ? cause : new ClosedChannelException();
+
+			synchronized (connectLock) {
+				if (!closed) {
+					if (failureCause == null) {
+						failureCause = closeCause;
+					}
+
+					if (established != null) {
+						established.close();
+					} else {
+						PendingRequest pending;
+						while ((pending = queuedRequests.poll()) != null) {
+							pending.completeExceptionally(closeCause);
+						}
+					}
+					closed = true;
+				}
+			}
+
+			// A closed pending connection must not stay registered. Otherwise every later request
+			// to this server is failed with the stale cause instead of triggering a new connection
+			// attempt. The two-argument remove leaves a newer connection for the address untouched.
+			pendingConnections.remove(serverAddress, this);
+		}
+
+		@Override
+		public String toString() {
+			synchronized (connectLock) {
+				return "PendingConnection{" +
+						"serverAddress=" + serverAddress +
+						", queuedRequests=" + queuedRequests.size() +
+						", established=" + (established != null) +
+						", closed=" + closed +
+						'}';
+			}
+		}
+
+		/**
+		 * A pending request queued while the channel is connecting.
+		 */
+		private final class PendingRequest extends CompletableFuture<RESP> {
+
+			private final REQ request;
+
+			private PendingRequest(REQ request) {
+				this.request = request;
+			}
+		}
+	}
+
+	/**
+	 * An established connection that wraps the actual channel instance and is
+	 * registered at the {@link ClientHandler} for callbacks.
+	 */
+	private class EstablishedConnection implements ClientHandlerCallback<RESP> {
+
+		/** Address of the server we are connected to. */
+		private final InetSocketAddress serverAddress;
+
+		/** The actual TCP channel. */
+		private final Channel channel;
+
+		/** Pending requests keyed by request ID. */
+		private final ConcurrentHashMap<Long, TimestampedCompletableFuture> pendingRequests = new ConcurrentHashMap<>();
+
+		/** Current request number used to assign unique request IDs. */
+		private final AtomicLong requestCount = new AtomicLong();
+
+		/** Reference to a failure that was reported by the channel. */
+		private final AtomicReference<Throwable> failureCause = new AtomicReference<>();
+
+		/**
+		 * Creates an established connection with the given channel.
+		 *
+		 * @param serverAddress Address of the server connected to
+		 * @param serializer The serializer handed to the {@link ClientHandler}
+		 * @param channel The actual TCP channel
+		 */
+		EstablishedConnection(
+				final InetSocketAddress serverAddress,
+				final MessageSerializer<REQ, RESP> serializer,
+				final Channel channel) {
+
+			this.serverAddress = Preconditions.checkNotNull(serverAddress);
+			this.channel = Preconditions.checkNotNull(channel);
+
+			// Add the client handler with the callback
+			channel.pipeline().addLast(
+					getClientName() + " Handler",
+					new ClientHandler<>(clientName, serializer, this)
+			);
+
+			stats.reportActiveConnection();
+		}
+
+		/**
+		 * Close the channel with a ClosedChannelException.
+		 */
+		void close() {
+			close(new ClosedChannelException());
+		}
+
+		/**
+		 * Close the channel with a cause and fail all pending requests with it.
+		 *
+		 * @param cause The cause to close the channel with.
+		 * @return {@code true} if this call closed the connection, {@code false} if it had
+		 *         already been closed or failed before
+		 */
+		private boolean close(Throwable cause) {
+			if (failureCause.compareAndSet(null, cause)) {
+				channel.close();
+				stats.reportInactiveConnection();
+
+				for (long requestId : pendingRequests.keySet()) {
+					failPendingRequest(requestId, cause);
+				}
+				return true;
+			}
+			return false;
+		}
+
+		/**
+		 * Returns a future holding the serialized request result.
+		 *
+		 * @param request the request to be sent.
+		 * @return Future holding the serialized result
+		 */
+		CompletableFuture<RESP> sendRequest(REQ request) {
+			final TimestampedCompletableFuture requestPromiseTs =
+					new TimestampedCompletableFuture(System.nanoTime());
+			final long requestId = requestCount.getAndIncrement();
+
+			try {
+				pendingRequests.put(requestId, requestPromiseTs);
+
+				stats.reportRequest();
+
+				final ByteBuf buf = MessageSerializer.serializeRequest(channel.alloc(), requestId, request);
+
+				channel.writeAndFlush(buf).addListener((ChannelFutureListener) future -> {
+					if (!future.isSuccess()) {
+						// Fail promise if not failed to write
+						failPendingRequest(requestId, future.cause());
+					}
+				});
+
+				// Check failure for possible race. We don't want any lingering
+				// promises after a failure, which can happen if we don't check
+				// this here. Note that close is treated as a failure as well.
+				final Throwable failure = failureCause.get();
+				if (failure != null) {
+					// Remove from pending requests to guard against concurrent
+					// removal and to make sure that we only count it once as failed.
+					failPendingRequest(requestId, failure);
+				}
+			} catch (Throwable t) {
+				// E.g. serialization failed: unregister the request so that it does not linger
+				// in pendingRequests forever, and count it as failed.
+				if (!failPendingRequest(requestId, t)) {
+					requestPromiseTs.completeExceptionally(t);
+				}
+			}
+
+			return requestPromiseTs;
+		}
+
+		/**
+		 * Removes the request from the pending requests and fails it, reporting the failure
+		 * to the statistics if this call is the one that failed it.
+		 *
+		 * @param requestId the ID of the request to fail.
+		 * @param cause the failure cause.
+		 * @return {@code true} if this call failed the request.
+		 */
+		private boolean failPendingRequest(long requestId, Throwable cause) {
+			final TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
+			if (pending != null && pending.completeExceptionally(cause)) {
+				stats.reportFailedRequest();
+				return true;
+			}
+			return false;
+		}
+
+		@Override
+		public void onRequestResult(long requestId, RESP response) {
+			final TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
+			if (pending != null && pending.complete(response)) {
+				long durationMillis = (System.nanoTime() - pending.getTimestamp()) / 1_000_000L;
+				stats.reportSuccessfulRequest(durationMillis);
+			}
+		}
+
+		@Override
+		public void onRequestFailure(long requestId, Throwable cause) {
+			failPendingRequest(requestId, cause);
+		}
+
+		@Override
+		public void onFailure(Throwable cause) {
+			if (close(cause)) {
+				// Remove from established channels, otherwise future
+				// requests will be handled by this failed channel.
+				establishedConnections.remove(serverAddress, this);
+			}
+		}
+
+		@Override
+		public String toString() {
+			return "EstablishedConnection{" +
+					"serverAddress=" + serverAddress +
+					", channel=" + channel +
+					", pendingRequests=" + pendingRequests.size() +
+					", requestCount=" + requestCount +
+					", failureCause=" + failureCause +
+					'}';
+		}
+
+		/**
+		 * Pair of promise and a timestamp.
+		 */
+		private class TimestampedCompletableFuture extends CompletableFuture<RESP> {
+
+			private final long timestampInNanos;
+
+			TimestampedCompletableFuture(long timestampInNanos) {
+				this.timestampInNanos = timestampInNanos;
+			}
+
+			public long getTimestamp() {
+				return timestampInNanos;
+			}
+		}
+	}
+}
