This is an automated email from the ASF dual-hosted git repository.

ivandasch pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new d02626640e0  IGNITE-20759 Fix writing post-handshake message in NIO server (#11036)
d02626640e0 is described below

commit d02626640e05772ca61faabe7f9c1718016bc264
Author: Ivan Daschinskiy <ivanda...@apache.org>
AuthorDate: Fri Nov 17 14:48:40 2023 +0300

     IGNITE-20759 Fix writing post-handshake message in NIO server (#11036)
---
 .../ignite/internal/util/nio/GridNioServer.java    |   3 +-
 .../internal/util/nio/ssl/BlockingSslHandler.java  |   4 +-
 .../tcp/internal/GridNioServerWrapper.java         |   6 -
 .../tcp/internal/TcpHandshakeExecutor.java         | 367 ++++++++++++---------
 .../tcp/GridTcpCommunicationSpiAbstractTest.java   |   2 +-
 .../tcp/TcpCommunicationHandshakeTimeoutTest.java  |   3 +-
 6 files changed, 224 insertions(+), 161 deletions(-)

diff --git a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java
index 695eaed5228..fc23516c8af 100644
--- a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java
+++ b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java
@@ -1425,7 +1425,8 @@ public class GridNioServer<T> {
             try {
                 boolean writeFinished = writeSslSystem(ses, sockCh);
 
-                if (!handshakeFinished) {
+                // If post-handshake message is not written fully (possible on 
JDK 17), we should retry.
+                if (!handshakeFinished || !writeFinished) {
                     if (writeFinished)
                         stopPollingForWrite(key, ses);
 
diff --git a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java
index 40ded3c314b..1da7c8ab62a 100644
--- a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java
+++ b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java
@@ -85,8 +85,8 @@ public class BlockingSslHandler {
         SocketChannel ch,
         boolean directBuf,
         ByteOrder order,
-        IgniteLogger log)
-        throws SSLException {
+        IgniteLogger log
+    ) {
         this.ch = ch;
         this.log = log;
         this.sslEngine = sslEngine;
diff --git a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java
index 8302b1c9508..187b7fc3edb 100644
--- a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java
+++ b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java
@@ -1186,12 +1186,6 @@ public class GridNioServerWrapper {
         try {
             return tcpHandshakeExecutor.tcpHandshake(ch, rmtNodeId, sslMeta, 
msg);
         }
-        catch (IOException e) {
-            if (log.isDebugEnabled())
-                log.debug("Failed to read from channel: " + e);
-
-            throw new IgniteCheckedException("Failed to read from channel.", 
e);
-        }
         finally {
             if (!timeoutObject.cancel())
                 throw handshakeTimeoutException();
diff --git a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java
index cb5e256503e..387ceed0fdb 100644
--- a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java
+++ b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.channels.SocketChannel;
 import java.util.UUID;
+import javax.net.ssl.SSLException;
 import org.apache.ignite.IgniteCheckedException;
 import org.apache.ignite.IgniteLogger;
 import org.apache.ignite.internal.util.nio.ssl.BlockingSslHandler;
@@ -67,228 +68,296 @@ public class TcpHandshakeExecutor {
      * @param ch Socket channel which using for handshake.
      * @param rmtNodeId Expected remote node.
      * @param sslMeta Required data for ssl.
-     * @param msg Handshake message which should be send during handshake.
+     * @param msg Handshake message which should be sent during handshake.
      * @return Handshake response from predefined variants from {@link 
RecoveryLastReceivedMessage}.
-     * @throws IgniteCheckedException If not related to IO exception happened.
-     * @throws IOException If reading or writing to socket is failed.
+     * @throws IgniteCheckedException If handshake failed.
      */
     public long tcpHandshake(
         SocketChannel ch,
         UUID rmtNodeId,
         GridSslMeta sslMeta,
         HandshakeMessage msg
-    ) throws IgniteCheckedException, IOException {
-        long rcvCnt;
+    ) throws IgniteCheckedException {
+        BlockingTransport transport = stateProvider.isSslEnabled() ?
+            new SslTransport(sslMeta, ch, directBuffer, log) : new 
TcpTransport(ch);
 
-        BlockingSslHandler sslHnd = null;
+        ByteBuffer buf = transport.recieveNodeId();
 
-        ByteBuffer buf;
+        if (buf == null)
+            return NEED_WAIT;
 
-        // Step 1. Get remote node response with the remote nodeId value.
-        if (stateProvider.isSslEnabled()) {
-            assert sslMeta != null;
-
-            sslHnd = new BlockingSslHandler(sslMeta.sslEngine(), ch, 
directBuffer, ByteOrder.LITTLE_ENDIAN, log);
-
-            if (!sslHnd.handshake())
-                throw new HandshakeException("SSL handshake is not 
completed.");
-
-            ByteBuffer handBuff = sslHnd.applicationBuffer();
-
-            if (handBuff.remaining() >= DIRECT_TYPE_SIZE) {
-                short msgType = makeMessageType(handBuff.get(0), 
handBuff.get(1));
-
-                if (msgType == HANDSHAKE_WAIT_MSG_TYPE)
-                    return NEED_WAIT;
-            }
-
-            if (handBuff.remaining() < NodeIdMessage.MESSAGE_FULL_SIZE) {
-                ByteBuffer readBuf = ByteBuffer.allocate(1000);
-
-                while (handBuff.remaining() < NodeIdMessage.MESSAGE_FULL_SIZE) 
{
-                    int read = ch.read(readBuf);
-
-                    if (read == -1)
-                        throw new HandshakeException("Failed to read remote 
node ID (connection closed).");
+        UUID rmtNodeId0 = U.bytesToUuid(buf.array(), DIRECT_TYPE_SIZE);
 
-                    readBuf.flip();
+        if (!rmtNodeId.equals(rmtNodeId0))
+            throw new HandshakeException("Remote node ID is not as expected 
[expected=" + rmtNodeId + ", rcvd=" + rmtNodeId0 + ']');
+        else if (log.isDebugEnabled())
+            log.debug("Received remote node ID: " + rmtNodeId0);
 
-                    sslHnd.decode(readBuf);
+        if (log.isDebugEnabled())
+            log.debug("Writing handshake message [rmtNode=" + rmtNodeId + ", 
msg=" + msg + ']');
 
-                    if (handBuff.remaining() >= DIRECT_TYPE_SIZE) {
-                        break;
-                    }
+        transport.sendHandshake(msg);
 
-                    readBuf.flip();
-                }
+        buf = transport.recieveAcknowledge();
 
-                buf = handBuff;
+        long rcvCnt = buf.getLong(DIRECT_TYPE_SIZE);
 
-                if (handBuff.remaining() >= DIRECT_TYPE_SIZE) {
-                    short msgType = makeMessageType(handBuff.get(0), 
handBuff.get(1));
+        if (log.isDebugEnabled())
+            log.debug("Received handshake message [rmtNode=" + rmtNodeId + ", 
rcvCnt=" + rcvCnt + ']');
 
-                    if (msgType == HANDSHAKE_WAIT_MSG_TYPE)
-                        return NEED_WAIT;
-                }
-            }
-            else
-                buf = handBuff;
+        if (rcvCnt == -1) {
+            if (log.isDebugEnabled())
+                log.debug("Connection rejected, will retry client creation 
[rmtNode=" + rmtNodeId + ']');
         }
-        else {
-            buf = ByteBuffer.allocate(NodeIdMessage.MESSAGE_FULL_SIZE);
 
-            for (int i = 0; i < NodeIdMessage.MESSAGE_FULL_SIZE; ) {
-                int read = ch.read(buf);
+        transport.onHandshakeFinished(sslMeta);
+
+        return rcvCnt;
+    }
 
-                if (read == -1)
+    /**
+     * Encapsulates handshake logic.
+     */
+    private abstract static class BlockingTransport {
+        /**
+         * Receive {@link NodeIdMessage}.
+         *
+         * @return Buffer with {@link NodeIdMessage}.
+         * @throws IgniteCheckedException If failed.
+         */
+        ByteBuffer recieveNodeId() throws IgniteCheckedException {
+            ByteBuffer buf = 
ByteBuffer.allocate(NodeIdMessage.MESSAGE_FULL_SIZE)
+                    .order(ByteOrder.LITTLE_ENDIAN);
+
+            for (int totalBytes = 0; totalBytes < 
NodeIdMessage.MESSAGE_FULL_SIZE; ) {
+                int readBytes = read(buf);
+
+                if (readBytes == -1)
                     throw new HandshakeException("Failed to read remote node 
ID (connection closed).");
 
-                if (read >= DIRECT_TYPE_SIZE) {
+                if (readBytes >= DIRECT_TYPE_SIZE) {
                     short msgType = makeMessageType(buf.get(0), buf.get(1));
 
                     if (msgType == HANDSHAKE_WAIT_MSG_TYPE)
-                        return NEED_WAIT;
+                        return null;
                 }
 
-                i += read;
+                totalBytes += readBytes;
             }
-        }
-
-        UUID rmtNodeId0 = U.bytesToUuid(buf.array(), DIRECT_TYPE_SIZE);
 
-        if (!rmtNodeId.equals(rmtNodeId0))
-            throw new HandshakeException("Remote node ID is not as expected 
[expected=" + rmtNodeId +
-                ", rcvd=" + rmtNodeId0 + ']');
-        else if (log.isDebugEnabled())
-            log.debug("Received remote node ID: " + rmtNodeId0);
-
-        if (stateProvider.isSslEnabled()) {
-            assert sslHnd != null;
+            return buf;
+        }
 
-            U.writeFully(ch, sslHnd.encrypt(ByteBuffer.wrap(U.IGNITE_HEADER)));
+        /**
+         * Send {@link HandshakeMessage} to remote node.
+         *
+         * @param msg Handshake message.
+         * @throws IgniteCheckedException If failed.
+         */
+        void sendHandshake(HandshakeMessage msg) throws IgniteCheckedException 
{
+            ByteBuffer buf = ByteBuffer.allocate(msg.getMessageSize() + 
U.IGNITE_HEADER.length)
+                    .order(ByteOrder.LITTLE_ENDIAN)
+                    .put(U.IGNITE_HEADER);
+
+            msg.writeTo(buf, null);
+            buf.flip();
+
+            write(buf);
         }
-        else
-            U.writeFully(ch, ByteBuffer.wrap(U.IGNITE_HEADER));
 
-        // Step 2. Prepare Handshake message to send to the remote node.
-        if (log.isDebugEnabled())
-            log.debug("Writing handshake message [rmtNode=" + rmtNodeId + ", 
msg=" + msg + ']');
+        /**
+         * Receive {@link RecoveryLastReceivedMessage} acknowledge message.
+         *
+         * @return Buffer with message.
+         * @throws IgniteCheckedException If failed.
+         */
+        ByteBuffer recieveAcknowledge() throws IgniteCheckedException {
+            ByteBuffer buf = 
ByteBuffer.allocate(RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE)
+                    .order(ByteOrder.LITTLE_ENDIAN);
 
-        buf = ByteBuffer.allocate(msg.getMessageSize());
+            for (int totalBytes = 0; totalBytes < 
RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE; ) {
+                int readBytes = read(buf);
 
-        buf.order(ByteOrder.LITTLE_ENDIAN);
+                if (readBytes == -1)
+                    throw new HandshakeException("Failed to read remote node 
recovery handshake " +
+                            "(connection closed).");
 
-        boolean written = msg.writeTo(buf, null);
+                totalBytes += readBytes;
+            }
 
-        assert written;
+            return buf;
+        }
 
-        buf.flip();
+        /**
+         * Read data from media.
+         *
+         * @param buf Buffer to read into.
+         * @return Bytes read.
+         * @throws IgniteCheckedException If failed.
+         */
+        abstract int read(ByteBuffer buf) throws IgniteCheckedException;
+
+        /**
+         * Write data fully.
+         * @param buf Buffer to write.
+         * @throws IgniteCheckedException If failed.
+         */
+        abstract void write(ByteBuffer buf) throws IgniteCheckedException;
+
+        /**
+         * Do some post-handshake job if needed.
+         *
+         * @param sslMeta Ssl meta.
+         */
+        void onHandshakeFinished(GridSslMeta sslMeta) {
+            // No-op.
+        }
+    }
 
-        if (stateProvider.isSslEnabled()) {
-            assert sslHnd != null;
+    /**
+     * Tcp plaintext transport.
+     */
+    private static class TcpTransport extends BlockingTransport {
+        /** */
+        private final SocketChannel ch;
 
-            U.writeFully(ch, sslHnd.encrypt(buf));
+        /** */
+        TcpTransport(SocketChannel ch) {
+            this.ch = ch;
         }
-        else
-            U.writeFully(ch, buf);
 
-        if (log.isDebugEnabled())
-            log.debug("Waiting for handshake [rmtNode=" + rmtNodeId + ']');
-
-        // Step 3. Waiting for response from the remote node with their 
receive count message.
-        if (stateProvider.isSslEnabled()) {
-            assert sslHnd != null;
+        /** {@inheritDoc} */
+        @Override int read(ByteBuffer buf) throws IgniteCheckedException {
+            try {
+                return ch.read(buf);
+            }
+            catch (IOException e) {
+                throw new IgniteCheckedException("Failed to read from 
channel", e);
+            }
+        }
 
-            buf = ByteBuffer.allocate(1000);
-            buf.order(ByteOrder.LITTLE_ENDIAN);
+        /** {@inheritDoc} */
+        @Override void write(ByteBuffer buf) throws IgniteCheckedException {
+            try {
+                U.writeFully(ch, buf);
+            }
+            catch (IOException e) {
+                throw new IgniteCheckedException("Failed to write to channel", 
e);
+            }
+        }
+    }
 
-            ByteBuffer decode = ByteBuffer.allocate(2 * buf.capacity());
-            decode.order(ByteOrder.LITTLE_ENDIAN);
+    /** Ssl transport */
+    private static class SslTransport extends BlockingTransport {
+        /** */
+        private static final int READ_BUFFER_CAPACITY = 1024;
 
-            for (int i = 0; i < RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE; 
) {
-                int read = ch.read(buf);
+        /** */
+        private final BlockingSslHandler handler;
 
-                if (read == -1)
-                    throw new HandshakeException("Failed to read remote node 
recovery handshake " +
-                        "(connection closed).");
+        /** */
+        private final SocketChannel ch;
 
-                buf.flip();
+        /** */
+        private final ByteBuffer readBuf;
 
-                ByteBuffer decode0 = sslHnd.decode(buf);
+        /** */
+        SslTransport(GridSslMeta meta, SocketChannel ch, boolean directBuf, 
IgniteLogger log) throws IgniteCheckedException {
+            try {
+                this.ch = ch;
+                handler = new BlockingSslHandler(meta.sslEngine(), ch, 
directBuf, ByteOrder.LITTLE_ENDIAN, log);
 
-                i += decode0.remaining();
+                if (!handler.handshake())
+                    throw new HandshakeException("SSL handshake is not 
completed.");
 
-                decode = appendAndResizeIfNeeded(decode, decode0);
+                readBuf = directBuf ? 
ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY) : 
ByteBuffer.allocate(READ_BUFFER_CAPACITY);
 
-                buf.clear();
+                readBuf.order(ByteOrder.LITTLE_ENDIAN);
+            }
+            catch (SSLException e) {
+                throw new IgniteCheckedException("SSL handhshake failed", e);
             }
+        }
 
-            decode.flip();
+        /** {@inheritDoc} */
+        @Override int read(ByteBuffer buf) throws IgniteCheckedException {
+            ByteBuffer appBuff = handler.applicationBuffer();
 
-            rcvCnt = decode.getLong(DIRECT_TYPE_SIZE);
+            int read = copy(appBuff, buf);
 
-            if (decode.limit() > 
RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE) {
-                decode.position(RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE);
+            if (read > 0)
+                return read;
 
-                sslMeta.decodedBuffer(decode);
-            }
+            try {
+                while (read == 0) {
+                    readBuf.clear();
 
-            ByteBuffer inBuf = sslHnd.inputBuffer();
+                    if (ch.read(readBuf) < 0)
+                        return -1;
 
-            if (inBuf.position() > 0)
-                sslMeta.encodedBuffer(inBuf);
-        }
-        else {
-            buf = 
ByteBuffer.allocate(RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE);
+                    readBuf.flip();
 
-            buf.order(ByteOrder.LITTLE_ENDIAN);
+                    handler.decode(readBuf);
 
-            for (int i = 0; i < RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE; 
) {
-                int read = ch.read(buf);
+                    read = copy(appBuff, buf);
+                }
+            }
+            catch (SSLException e) {
+                throw new IgniteCheckedException("Failed to decrypt data", e);
+            }
+            catch (IOException e) {
+                throw new IgniteCheckedException("Failed to read from 
channel", e);
+            }
 
-                if (read == -1)
-                    throw new HandshakeException("Failed to read remote node 
recovery handshake " +
-                        "(connection closed).");
+            return read;
+        }
 
-                i += read;
+        /** {@inheritDoc} */
+        @Override void write(ByteBuffer buf) throws IgniteCheckedException {
+            try {
+                U.writeFully(ch, handler.encrypt(buf));
+            }
+            catch (SSLException e) {
+                throw new IgniteCheckedException("Failed to encrypt data", e);
+            }
+            catch (IOException e) {
+                throw new IgniteCheckedException("Failed to write to channel", 
e);
             }
-
-            rcvCnt = buf.getLong(DIRECT_TYPE_SIZE);
         }
 
-        if (log.isDebugEnabled())
-            log.debug("Received handshake message [rmtNode=" + rmtNodeId + ", 
rcvCnt=" + rcvCnt + ']');
+        /** {@inheritDoc} */
+        @Override void onHandshakeFinished(GridSslMeta sslMeta) {
+            ByteBuffer appBuff = handler.applicationBuffer();
+            if (appBuff.hasRemaining())
+                sslMeta.decodedBuffer(appBuff);
 
-        if (rcvCnt == -1) {
-            if (log.isDebugEnabled())
-                log.debug("Connection rejected, will retry client creation 
[rmtNode=" + rmtNodeId + ']');
+            ByteBuffer inBuf = handler.inputBuffer();
+
+            if (inBuf.position() > 0)
+                sslMeta.encodedBuffer(inBuf);
         }
 
-        return rcvCnt;
-    }
+        /**
+         * @param src Source buffer.
+         * @param dst Destination buffer.
+         * @return Bytes copied.
+         */
+        private int copy(ByteBuffer src, ByteBuffer dst) {
+            int remaining = Math.min(src.remaining(), dst.remaining());
 
-    /**
-     * @param target Target buffer to append to.
-     * @param src Source buffer to get data.
-     * @return Original or expanded buffer.
-     */
-    private ByteBuffer appendAndResizeIfNeeded(ByteBuffer target, ByteBuffer 
src) {
-        if (target.remaining() < src.remaining()) {
-            int newSize = Math.max(target.capacity() * 2, target.capacity() + 
src.remaining());
+            if (remaining > 0) {
+                int oldLimit = src.limit();
 
-            ByteBuffer tmp = ByteBuffer.allocate(newSize);
+                src.limit(src.position() + remaining);
 
-            tmp.order(target.order());
+                dst.put(src);
 
-            target.flip();
+                src.limit(oldLimit);
+            }
 
-            tmp.put(target);
+            src.compact();
 
-            target = tmp;
+            return remaining;
         }
-
-        target.put(src);
-
-        return target;
     }
 }
diff --git a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java
index a70078d3951..5752f790957 100644
--- a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java
+++ b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java
@@ -76,7 +76,7 @@ abstract class GridTcpCommunicationSpiAbstractTest extends GridAbstractCommunica
 
         // Test idle clients remove.
         for (CommunicationSpi<Message> spi : spis.values()) {
-            ConcurrentMap<UUID, GridCommunicationClient> clients = 
GridTestUtils.getFieldValue(spi, "clientPool", "clients");
+            ConcurrentMap<UUID, GridCommunicationClient[]> clients = 
GridTestUtils.getFieldValue(spi, "clientPool", "clients");
 
             assertEquals(getSpiCount() - 1, clients.size());
 
diff --git a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java
index 819dce188df..c8eba14075c 100644
--- a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java
+++ b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java
@@ -17,7 +17,6 @@
 
 package org.apache.ignite.spi.communication.tcp;
 
-import java.io.IOException;
 import java.nio.channels.SocketChannel;
 import java.util.Arrays;
 import java.util.HashSet;
@@ -137,7 +136,7 @@ public class TcpCommunicationHandshakeTimeoutTest extends GridCommonAbstractTest
 
         /** {@inheritDoc} */
         @Override public long tcpHandshake(SocketChannel ch, UUID rmtNodeId, 
GridSslMeta sslMeta,
-            HandshakeMessage msg) throws IgniteCheckedException, IOException {
+            HandshakeMessage msg) throws IgniteCheckedException {
             if (needToDelayd.get()) {
                 needToDelayd.set(false);
 

Reply via email to