From: "Daniel P. Berrange" <berra...@redhat.com>

Currently if the keepalive timer triggers, the 'wantClose'
flag is set on the virNetClient and a controlled shutdown
is then performed. If an I/O error occurs during a read or
write on the connection, however, an error is raised back to
the caller but the connection isn't marked for close. This
patch ensures that all I/O error scenarios result in the
connection being marked for close, and records the reason
for the close (keepalive, EOF, I/O error or explicit client
request) in a new 'closeReason' field.

Signed-off-by: Daniel P. Berrange <berra...@redhat.com>
---
 src/rpc/virnetclient.c |   62 +++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 51 insertions(+), 11 deletions(-)

diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 0e7e423..a258746 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -101,6 +101,7 @@ struct _virNetClient {
 
     virKeepAlivePtr keepalive;
     bool wantClose;
+    int closeReason;
 };
 
 
@@ -108,6 +109,8 @@ static void virNetClientIOEventLoopPassTheBuck(virNetClientPtr client,
                                                virNetClientCallPtr thiscall);
 static int virNetClientQueueNonBlocking(virNetClientPtr client,
                                         virNetMessagePtr msg);
+static void virNetClientCloseInternal(virNetClientPtr client,
+                                      int reason);
 
 
 static void virNetClientLock(virNetClientPtr client)
@@ -261,7 +264,7 @@ virNetClientKeepAliveStop(virNetClientPtr client)
 static void
 virNetClientKeepAliveDeadCB(void *opaque)
 {
-    virNetClientClose(opaque);
+    virNetClientCloseInternal(opaque, VIR_CONNECT_CLOSE_REASON_KEEPALIVE);
 }
 
 static int
@@ -484,16 +487,26 @@ void virNetClientFree(virNetClientPtr client)
 
 
 static void
+virNetClientMarkClose(virNetClientPtr client,
+                      int reason)
+{
+    VIR_DEBUG("client=%p, reason=%d", client, reason);
+    virNetSocketRemoveIOCallback(client->sock);
+    client->wantClose = true;
+    client->closeReason = reason;
+}
+
+
+static void
 virNetClientCloseLocked(virNetClientPtr client)
 {
     virKeepAlivePtr ka;
 
-    VIR_DEBUG("client=%p, sock=%p", client, client->sock);
+    VIR_DEBUG("client=%p, sock=%p, reason=%d", client, client->sock, client->closeReason);
 
     if (!client->sock)
         return;
 
-    virNetSocketRemoveIOCallback(client->sock);
     virNetSocketFree(client->sock);
     client->sock = NULL;
     virNetTLSSessionFree(client->tls);
@@ -518,16 +531,21 @@ virNetClientCloseLocked(virNetClientPtr client)
     }
 }
 
-void virNetClientClose(virNetClientPtr client)
+static void virNetClientCloseInternal(virNetClientPtr client,
+                                      int reason)
 {
     VIR_DEBUG("client=%p", client);
 
     if (!client)
         return;
 
     virNetClientLock(client);
-
-    client->wantClose = true;
+    /* sock and wantClose can change concurrently; test them under the lock */
+    if (!client->sock || client->wantClose) {
+        virNetClientUnlock(client);
+        return;
+    }
+    virNetClientMarkClose(client, reason);
 
     /* If there is a thread polling for data on the socket, wake the thread up
      * otherwise try to pass the buck to a possibly waiting thread. If no
@@ -548,6 +566,12 @@ void virNetClientClose(virNetClientPtr client)
 }
 
 
+void virNetClientClose(virNetClientPtr client)
+{
+    virNetClientCloseInternal(client, VIR_CONNECT_CLOSE_REASON_CLIENT);
+}
+
+
 #if HAVE_SASL
 void virNetClientSetSASLSession(virNetClientPtr client,
                                 virNetSASLSessionPtr sasl)
@@ -1351,7 +1375,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
         }
 
         if (virKeepAliveTrigger(client->keepalive, &msg)) {
-            client->wantClose = true;
+            virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_KEEPALIVE);
         } else if (msg && virNetClientQueueNonBlocking(client, msg) < 0) {
             VIR_WARN("Could not queue keepalive request");
             virNetMessageFree(msg);
@@ -1374,18 +1398,23 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
             if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) {
                 virReportSystemError(errno, "%s",
                                      _("read on wakeup fd failed"));
+                virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
                 goto error;
             }
         }
 
         if (fds[0].revents & POLLOUT) {
-            if (virNetClientIOHandleOutput(client) < 0)
+            if (virNetClientIOHandleOutput(client) < 0) {
+                virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
                 goto error;
+            }
         }
 
         if (fds[0].revents & POLLIN) {
-            if (virNetClientIOHandleInput(client) < 0)
+            if (virNetClientIOHandleInput(client) < 0) {
+                virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
                 goto error;
+            }
         }
 
         /* Iterate through waiting calls and if any are
@@ -1410,6 +1439,7 @@ static int virNetClientIOEventLoop(virNetClientPtr client,
         }
 
         if (fds[0].revents & (POLLHUP | POLLERR)) {
+            virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_EOF);
             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
                            _("received hangup / error event on socket"));
             goto error;
@@ -1441,6 +1471,9 @@ static void virNetClientIOUpdateCallback(virNetClientPtr client,
 {
     int events = 0;
 
+    if (client->wantClose)
+        return;
+
     if (enableCallback) {
         events |= VIR_EVENT_HANDLE_READABLE;
         virNetClientCallMatchPredicate(client->waitDispatch,
@@ -1623,6 +1656,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
 
     virNetClientLock(client);
 
+    VIR_DEBUG("client=%p wantclose=%d", client, client->wantClose);
+
     if (!client->sock)
         goto done;
 
@@ -1635,18 +1670,21 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
     if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) {
         VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or "
                   "VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__);
-        virNetSocketRemoveIOCallback(sock);
+        virNetClientMarkClose(client,
+                              (events & VIR_EVENT_HANDLE_HANGUP) ?
+                              VIR_CONNECT_CLOSE_REASON_EOF :
+                              VIR_CONNECT_CLOSE_REASON_ERROR);
         goto done;
     }
 
     if (events & VIR_EVENT_HANDLE_WRITABLE) {
         if (virNetClientIOHandleOutput(client) < 0)
-            virNetSocketRemoveIOCallback(sock);
+            virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
     }
 
     if (events & VIR_EVENT_HANDLE_READABLE) {
         if (virNetClientIOHandleInput(client) < 0)
-            virNetSocketRemoveIOCallback(sock);
+            virNetClientMarkClose(client, VIR_CONNECT_CLOSE_REASON_ERROR);
     }
 
     /* Remove completed calls or signal their threads. */
@@ -1656,6 +1694,8 @@ void virNetClientIncomingEvent(virNetSocketPtr sock,
     virNetClientIOUpdateCallback(client, true);
 
 done:
+    if (client->wantClose)
+        virNetClientCloseLocked(client);
     virNetClientUnlock(client);
 }
 
-- 
1.7.10.4

--
libvir-list mailing list
libvir-list@redhat.com
https://www.redhat.com/mailman/listinfo/libvir-list

Reply via email to