Handle SHUT_RD and SHUT_WR shutdown options.

To allow the send and receive sides to be shut down independently,
split the connected state into separate read and write sub-states.
This lets a socket remain partially connected, i.e. open for reads
only or for writes only.

SHUT_WR support is needed for netperf to work properly: netperf
shuts down a connection by having the client use SHUT_WR, after
which the server completes the disconnect with SHUT_RDWR.  This
patch eliminates the following error message from netperf:

'shutdown_control: no response received  errno 95'

Signed-off-by: Sean Hefty <sean.he...@intel.com>
---
 src/rsocket.c |  157 ++++++++++++++++++++++++++++++++++-----------------------
 1 files changed, 95 insertions(+), 62 deletions(-)

diff --git a/src/rsocket.c b/src/rsocket.c
index c833d46..c77dd4a 100644
--- a/src/rsocket.c
+++ b/src/rsocket.c
@@ -96,7 +96,8 @@ enum {
 #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF)
 
 enum {
-       RS_CTRL_DISCONNECT
+       RS_CTRL_DISCONNECT,
+       RS_CTRL_SHUTDOWN
 };
 
 struct rs_msg {
@@ -136,16 +137,20 @@ union rs_wr_id {
  */
 enum rs_state {
        rs_init,
-       rs_bound,
-       rs_listening,
-       rs_resolving_addr,
-       rs_resolving_route,
-       rs_connecting,
-       rs_accepting,
-       rs_connect_error,
-       rs_connected,
-       rs_disconnected,
-       rs_error
+       rs_bound           =                0x0001,
+       rs_listening       =                0x0002,
+       rs_opening         =                0x0004,
+       rs_resolving_addr  = rs_opening |   0x0010,
+       rs_resolving_route = rs_opening |   0x0020,
+       rs_connecting      = rs_opening |   0x0040,
+       rs_accepting       = rs_opening |   0x0080,
+       rs_connected       =                0x0100,
+       rs_connect_wr      =                0x0200,
+       rs_connect_rd      =                0x0400,
+       rs_connect_rdwr    = rs_connected | rs_connect_rd | rs_connect_wr,
+       rs_connect_error   =                0x0800,
+       rs_disconnected    =                0x1000,
+       rs_error           =                0x2000,
 };
 
 #define RS_OPT_SWAP_SGL 1
@@ -162,7 +167,7 @@ struct rsocket {
        uint64_t          so_opts;
        uint64_t          tcp_opts;
        uint64_t          ipv6_opts;
-       enum rs_state     state;
+       int               state;
        int               cq_armed;
        int               retries;
        int               err;
@@ -321,7 +326,7 @@ static int rs_set_nonblocking(struct rsocket *rs, long arg)
        if (rs->cm_id->recv_cq_channel)
                ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
 
-       if (!ret && rs->state != rs_connected)
+       if (!ret && rs->state < rs_connected)
                ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg);
 
        return ret;
@@ -628,7 +633,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen)
        rs_set_conn_data(new_rs, &param, &cresp);
        ret = rdma_accept(new_rs->cm_id, &param);
        if (!ret)
-               new_rs->state = rs_connected;
+               new_rs->state = rs_connect_rdwr;
        else if (errno == EAGAIN || errno == EWOULDBLOCK)
                new_rs->state = rs_accepting;
        else
@@ -715,7 +720,7 @@ connected:
                }
 
                rs_save_conn_data(rs, cresp);
-               rs->state = rs_connected;
+               rs->state = rs_connect_rdwr;
                break;
        case rs_accepting:
                if (!(rs->fd_flags & O_NONBLOCK))
@@ -725,7 +730,7 @@ connected:
                if (ret)
                        break;
 
-               rs->state = rs_connected;
+               rs->state = rs_connect_rdwr;
                break;
        default:
                ret = ERR(EINVAL);
@@ -752,6 +757,13 @@ int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen)
        return rs_do_connect(rs);
 }
 
+static void rs_shutdown_state(struct rsocket *rs, int state)
+{
+       rs->state &= ~state;
+       if (rs->state == rs_connected)
+               rs->state = rs_disconnected;
+}
+
 static int rs_post_write(struct rsocket *rs, uint64_t wr_id,
                         struct ibv_sge *sgl, int nsge,
                         uint32_t imm_data, int flags,
@@ -852,7 +864,7 @@ static int rs_give_credits(struct rsocket *rs)
 {
        return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
                ((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
-              rs->ctrl_avail && (rs->state == rs_connected);
+              rs->ctrl_avail && (rs->state & rs_connected);
 }
 
 static void rs_update_credits(struct rsocket *rs)
@@ -882,7 +894,9 @@ static int rs_poll_cq(struct rsocket *rs)
                        case RS_OP_CTRL:
                                if (rs_msg_data(imm_data) == 
RS_CTRL_DISCONNECT) {
                                        rs->state = rs_disconnected;
-                                       return ERR(ECONNRESET);
+                                       return 0;
+                               } else if (rs_msg_data(imm_data) == 
RS_CTRL_SHUTDOWN) {
+                                       rs_shutdown_state(rs, rs_connect_rd);
                                }
                                break;
                        default:
@@ -900,14 +914,14 @@ static int rs_poll_cq(struct rsocket *rs)
                        } else {
                                rs->ctrl_avail++;
                        }
-                       if (wc.status != IBV_WC_SUCCESS && rs->state == 
rs_connected) {
+                       if (wc.status != IBV_WC_SUCCESS && (rs->state & 
rs_connected)) {
                                rs->state = rs_error;
                                rs->err = EIO;
                        }
                }
        }
 
-       if (rs->state == rs_connected) {
+       if (rs->state & rs_connected) {
                while (!ret && rcnt--)
                        ret = rdma_post_recvv(rs->cm_id, NULL, NULL, 0);
 
@@ -932,7 +946,7 @@ static int rs_get_cq_event(struct rsocket *rs)
        if (!ret) {
                ibv_ack_cq_events(rs->cm_id->recv_cq, 1);
                rs->cq_armed = 0;
-       } else if (errno != EAGAIN && rs->state == rs_connected) {
+       } else if (errno != EAGAIN) {
                rs->state = rs_error;
        }
 
@@ -1043,7 +1057,7 @@ static int rs_can_send(struct rsocket *rs)
 
 static int rs_conn_can_send(struct rsocket *rs)
 {
-       return rs_can_send(rs) || (rs->state != rs_connected);
+       return rs_can_send(rs) || !(rs->state & rs_connect_wr);
 }
 
 static int rs_can_send_ctrl(struct rsocket *rs)
@@ -1058,7 +1072,7 @@ static int rs_have_rdata(struct rsocket *rs)
 
 static int rs_conn_have_rdata(struct rsocket *rs)
 {
-       return rs_have_rdata(rs) || (rs->state != rs_connected);
+       return rs_have_rdata(rs) || !(rs->state & rs_connect_rd);
 }
 
 static int rs_all_sends_done(struct rsocket *rs)
@@ -1111,7 +1125,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
        int ret;
 
        rs = idm_at(&idm, socket);
-       if (rs->state < rs_connected) {
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS)
@@ -1122,7 +1136,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
        fastlock_acquire(&rs->rlock);
        if (!rs_have_rdata(rs)) {
                ret = rs_get_comp(rs, rs_nonblocking(rs, flags), 
rs_conn_have_rdata);
-               if (ret && errno != ECONNRESET)
+               if (ret)
                        goto out;
        }
 
@@ -1213,7 +1227,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
        int ret = 0;
 
        rs = idm_at(&idm, socket);
-       if (rs->state < rs_connected) {
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS)
@@ -1229,7 +1243,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
                                          rs_conn_can_send);
                        if (ret)
                                break;
-                       if (rs->state != rs_connected) {
+                       if (!(rs->state & rs_connect_wr)) {
                                ret = ERR(ECONNRESET);
                                break;
                        }
@@ -1322,7 +1336,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags
        int i, ret = 0;
 
        rs = idm_at(&idm, socket);
-       if (rs->state < rs_connected) {
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS)
@@ -1343,7 +1357,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags
                                          rs_conn_can_send);
                        if (ret)
                                break;
-                       if (rs->state != rs_connected) {
+                       if (!(rs->state & rs_connect_wr)) {
                                ret = ERR(ECONNRESET);
                                break;
                        }
@@ -1435,17 +1449,35 @@ static int rs_poll_rs(struct rsocket *rs, int events,
        short revents;
        int ret;
 
-       switch (rs->state) {
-       case rs_listening:
+check_cq:
+       if ((rs->state & rs_connected) || (rs->state == rs_disconnected) ||
+           (rs->state & rs_error)) {
+               rs_process_cq(rs, nonblock, test);
+
+               revents = 0;
+               if ((events & POLLIN) && rs_conn_have_rdata(rs))
+                       revents |= POLLIN;
+               if ((events & POLLOUT) && rs_can_send(rs))
+                       revents |= POLLOUT;
+               if (!(rs->state & rs_connected)) {
+                       if (rs->state == rs_disconnected)
+                               revents |= POLLHUP;
+                       else
+                               revents |= POLLERR;
+               }
+
+               return revents;
+       }
+
+       if (rs->state == rs_listening) {
                fds.fd = rs->cm_id->channel->fd;
                fds.events = events;
                fds.revents = 0;
                poll(&fds, 1, 0);
                return fds.revents;
-       case rs_resolving_addr:
-       case rs_resolving_route:
-       case rs_connecting:
-       case rs_accepting:
+       }
+
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS) {
@@ -1455,28 +1487,13 @@ static int rs_poll_rs(struct rsocket *rs, int events,
                                return POLLOUT;
                        }
                }
-               /* fall through */
-       case rs_connected:
-       case rs_disconnected:
-       case rs_error:
-               rs_process_cq(rs, nonblock, test);
-
-               revents = 0;
-               if ((events & POLLIN) && rs_have_rdata(rs))
-                       revents |= POLLIN;
-               if ((events & POLLOUT) && rs_can_send(rs))
-                       revents |= POLLOUT;
-               if (rs->state == rs_disconnected)
-                       revents |= POLLHUP;
-               else if (rs->state == rs_error)
-                       revents |= POLLERR;
+               goto check_cq;
+       }
 
-               return revents;
-       case rs_connect_error:
+       if (rs->state == rs_connect_error)
                return (rs->err && events & POLLOUT) ? POLLOUT : 0;
-       default:
-               return 0;
-       }
+
+       return 0;
 }
 
 static int rs_poll_check(struct pollfd *fds, nfds_t nfds)
@@ -1688,14 +1705,26 @@ int rselect(int nfds, fd_set *readfds, fd_set *writefds,
 int rshutdown(int socket, int how)
 {
        struct rsocket *rs;
-       int ret = 0;
+       int ctrl, ret = 0;
 
        rs = idm_at(&idm, socket);
+       if (how == SHUT_RD) {
+               rs_shutdown_state(rs, rs_connect_rd);
+               return 0;
+       }
+
        if (rs->fd_flags & O_NONBLOCK)
                rs_set_nonblocking(rs, 0);
 
-       if (rs->state == rs_connected) {
-               rs->state = rs_disconnected;
+       if (rs->state & rs_connected) {
+               if (how == SHUT_RDWR) {
+                       ctrl = RS_CTRL_DISCONNECT;
+                       rs->state = rs_disconnected;
+               } else {
+                       rs_shutdown_state(rs, rs_connect_wr);
+                       ctrl = (rs->state & rs_connected) ?
+                               RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT;
+               }
                if (!rs_can_send_ctrl(rs)) {
                        ret = rs_process_cq(rs, 0, rs_can_send_ctrl);
                        if (ret)
@@ -1704,13 +1733,16 @@ int rshutdown(int socket, int how)
 
                rs->ctrl_avail--;
                ret = rs_post_write(rs, 0, NULL, 0,
-                                   rs_msg_set(RS_OP_CTRL, RS_CTRL_DISCONNECT),
+                                   rs_msg_set(RS_OP_CTRL, ctrl),
                                    0, 0, 0);
        }
 
-       if (!rs_all_sends_done(rs) && rs->state != rs_error)
+       if (!rs_all_sends_done(rs) && !(rs->state & rs_error))
                rs_process_cq(rs, 0, rs_all_sends_done);
 
+       if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected))
+               rs_set_nonblocking(rs, 1);
+
        return 0;
 }
 
@@ -1719,7 +1751,7 @@ int rclose(int socket)
        struct rsocket *rs;
 
        rs = idm_at(&idm, socket);
-       if (rs->state == rs_connected)
+       if (rs->state & rs_connected)
                rshutdown(socket, SHUT_RDWR);
 
        rs_free(rs);
@@ -1830,8 +1862,9 @@ int rsetsockopt(int socket, int level, int optname,
                default:
                        break;
                }
+               break;
        case SOL_RDMA:
-               if (rs->state > rs_listening) {
+               if (rs->state >= rs_opening) {
                        ret = ERR(EINVAL);
                        break;
                }


--
To unsubscribe from this list: send the line "unsubscribe linux-rdma" in
the body of a message to majord...@vger.kernel.org
More majordomo info at  http://vger.kernel.org/majordomo-info.html

Reply via email to