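Currently the socket code will unlink any UNIX socket path which is
associated with a server socket. This is not fine-grained enough, as we
need to avoid unlinking the paths of server sockets that were passed to
us by systemd. Replace the implicit "!isClient" check with an explicit
unlinkUNIX flag, which is set for sockets created by
virNetSocketNewListenUNIX and chosen explicitly by callers of
virNetSocketNewListenFD. When restoring state after re-exec, fall back
to the old "isClient" field if "unlinkUNIX" is absent.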

Signed-off-by: Daniel P. Berrangé <berra...@redhat.com>
---
 src/locking/lock_daemon.c     |  1 +
 src/logging/log_daemon.c      |  1 +
 src/rpc/virnetserverservice.c |  3 ++
 src/rpc/virnetserverservice.h |  1 +
 src/rpc/virnetsocket.c        | 57 ++++++++++++++++++++---------------
 src/rpc/virnetsocket.h        |  1 +
 6 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/src/locking/lock_daemon.c b/src/locking/lock_daemon.c
index c10b2d383c..0f90606be6 100644
--- a/src/locking/lock_daemon.c
+++ b/src/locking/lock_daemon.c
@@ -619,6 +619,7 @@ virLockDaemonSetupNetworkingSystemD(virNetServerPtr 
lockSrv, virNetServerPtr adm
          * so the first FD we'll get is '3'. */
         if (!(svc = virNetServerServiceNewFDs(fds,
                                               ARRAY_CARDINALITY(fds),
+                                              false,
                                               0,
                                               NULL,
                                               false, 0, 1)))
diff --git a/src/logging/log_daemon.c b/src/logging/log_daemon.c
index 6531999381..30c70a20dd 100644
--- a/src/logging/log_daemon.c
+++ b/src/logging/log_daemon.c
@@ -554,6 +554,7 @@ virLogDaemonSetupNetworkingSystemD(virNetServerPtr logSrv, 
virNetServerPtr admin
          * so the first FD we'll get is '3'. */
         if (!(svc = virNetServerServiceNewFDs(fds,
                                               ARRAY_CARDINALITY(fds),
+                                              false,
                                               0,
                                               NULL,
                                               false, 0, 1)))
diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c
index 0d2f264696..315a4950df 100644
--- a/src/rpc/virnetserverservice.c
+++ b/src/rpc/virnetserverservice.c
@@ -121,6 +121,7 @@ virNetServerServiceNewFDOrUNIX(const char *path,
          */
         return virNetServerServiceNewFDs(fds,
                                          ARRAY_CARDINALITY(fds),
+                                         false,
                                          auth,
                                          tls,
                                          readonly,
@@ -257,6 +258,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const 
char *path,
 
 virNetServerServicePtr virNetServerServiceNewFDs(int *fds,
                                                  size_t nfds,
+                                                 bool unlinkUNIX,
                                                  int auth,
                                                  virNetTLSContextPtr tls,
                                                  bool readonly,
@@ -272,6 +274,7 @@ virNetServerServicePtr virNetServerServiceNewFDs(int *fds,
 
     for (i = 0; i < nfds; i++) {
         if (virNetSocketNewListenFD(fds[i],
+                                    unlinkUNIX,
                                     &socks[i]) < 0)
             goto cleanup;
     }
diff --git a/src/rpc/virnetserverservice.h b/src/rpc/virnetserverservice.h
index 59ee51e5ee..73d61dde99 100644
--- a/src/rpc/virnetserverservice.h
+++ b/src/rpc/virnetserverservice.h
@@ -62,6 +62,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char 
*path,
                                                   size_t nrequests_client_max);
 virNetServerServicePtr virNetServerServiceNewFDs(int *fd,
                                                  size_t nfds,
+                                                 bool unlinkUNIX,
                                                  int auth,
                                                  virNetTLSContextPtr tls,
                                                  bool readonly,
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index fc13b1654a..a462c3eb05 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -81,6 +81,7 @@ struct _virNetSocket {
     bool client;
     bool ownsFd;
     bool quietEOF;
+    bool unlinkUNIX;
 
     /* Event callback fields */
     virNetSocketIOFunc func;
@@ -216,10 +217,13 @@ int virNetSocketCheckProtocols(bool *hasIPv4,
 }
 
 
-static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
-                                       virSocketAddrPtr remoteAddr,
-                                       bool isClient,
-                                       int fd, int errfd, pid_t pid)
+static virNetSocketPtr
+virNetSocketNew(virSocketAddrPtr localAddr,
+                virSocketAddrPtr remoteAddr,
+                int fd,
+                int errfd,
+                pid_t pid,
+                bool unlinkUNIX)
 {
     virNetSocketPtr sock;
     int no_slow_start = 1;
@@ -254,6 +258,7 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr 
localAddr,
     sock->pid = pid;
     sock->watch = -1;
     sock->ownsFd = true;
+    sock->unlinkUNIX = unlinkUNIX;
 
     /* Disable nagle for TCP sockets */
     if (sock->localAddr.data.sa.sa_family == AF_INET ||
@@ -280,8 +285,6 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr 
localAddr,
         !(sock->remoteAddrStrURI = virSocketAddrFormatFull(remoteAddr, true, 
NULL)))
         goto error;
 
-    sock->client = isClient;
-
     PROBE(RPC_SOCKET_NEW,
           "sock=%p fd=%d errfd=%d pid=%lld localAddr=%s, remoteAddr=%s",
           sock, fd, errfd, (long long)pid,
@@ -427,7 +430,7 @@ int virNetSocketNewListenTCP(const char *nodename,
         if (VIR_EXPAND_N(socks, nsocks, 1) < 0)
             goto error;
 
-        if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, false, fd, -1, 
0)))
+        if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, fd, -1, 0, 
false)))
             goto error;
         runp = runp->ai_next;
         fd = -1;
@@ -513,7 +516,7 @@ int virNetSocketNewListenUNIX(const char *path,
         goto error;
     }
 
-    if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0)))
+    if (!(*retsock = virNetSocketNew(&addr, NULL, fd, -1, 0, true)))
         goto error;
 
     return 0;
@@ -538,6 +541,7 @@ int virNetSocketNewListenUNIX(const char *path 
ATTRIBUTE_UNUSED,
 #endif
 
 int virNetSocketNewListenFD(int fd,
+                            bool unlinkUNIX,
                             virNetSocketPtr *retsock)
 {
     virSocketAddr addr;
@@ -551,7 +555,7 @@ int virNetSocketNewListenFD(int fd,
         return -1;
     }
 
-    if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0)))
+    if (!(*retsock = virNetSocketNew(&addr, NULL, fd, -1, 0, unlinkUNIX)))
         return -1;
 
     return 0;
@@ -627,7 +631,7 @@ int virNetSocketNewConnectTCP(const char *nodename,
         goto error;
     }
 
-    if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 
0)))
+    if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, fd, -1, 0, 
false)))
         goto error;
 
     freeaddrinfo(ai);
@@ -752,7 +756,7 @@ int virNetSocketNewConnectUNIX(const char *path,
         goto cleanup;
     }
 
-    if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 
0)))
+    if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, fd, -1, 0, 
false)))
         goto cleanup;
 
     ret = 0;
@@ -820,7 +824,7 @@ int virNetSocketNewConnectCommand(virCommandPtr cmd,
     VIR_FORCE_CLOSE(sv[1]);
     VIR_FORCE_CLOSE(errfd[1]);
 
-    if (!(*retsock = virNetSocketNew(NULL, NULL, true, sv[0], errfd[0], pid)))
+    if (!(*retsock = virNetSocketNew(NULL, NULL, sv[0], errfd[0], pid, false)))
         goto error;
 
     virCommandFree(cmd);
@@ -1219,7 +1223,7 @@ int virNetSocketNewConnectSockFD(int sockfd,
         return -1;
     }
 
-    if (!(*retsock = virNetSocketNew(&localAddr, NULL, true, sockfd, -1, -1)))
+    if (!(*retsock = virNetSocketNew(&localAddr, NULL, sockfd, -1, -1, false)))
         return -1;
 
     return 0;
@@ -1231,7 +1235,7 @@ virNetSocketPtr 
virNetSocketNewPostExecRestart(virJSONValuePtr object)
     virSocketAddr localAddr;
     virSocketAddr remoteAddr;
     int fd, thepid, errfd;
-    bool isClient;
+    bool unlinkUNIX;
 
     if (virJSONValueObjectGetNumberInt(object, "fd", &fd) < 0) {
         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -1250,10 +1254,15 @@ virNetSocketPtr 
virNetSocketNewPostExecRestart(virJSONValuePtr object)
                        _("Missing errfd data in JSON document"));
         return NULL;
     }
-    if (virJSONValueObjectGetBoolean(object, "isClient", &isClient) < 0) {
-        virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
-                       _("Missing isClient data in JSON document"));
-        return NULL;
+
+    if (virJSONValueObjectGetBoolean(object, "unlinkUNIX", &unlinkUNIX) < 0) {
+        bool isClient;
+        if (virJSONValueObjectGetBoolean(object, "isClient", &isClient) < 0) {
+            virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
+                           _("Missing unlinkUNIX/isClient data in JSON 
document"));
+            return NULL;
+        }
+        unlinkUNIX = !isClient;
     }
 
     memset(&localAddr, 0, sizeof(localAddr));
@@ -1272,7 +1281,7 @@ virNetSocketPtr 
virNetSocketNewPostExecRestart(virJSONValuePtr object)
     }
 
     return virNetSocketNew(&localAddr, &remoteAddr,
-                           isClient, fd, errfd, thepid);
+                           fd, errfd, thepid, unlinkUNIX);
 }
 
 
@@ -1309,7 +1318,7 @@ virJSONValuePtr 
virNetSocketPreExecRestart(virNetSocketPtr sock)
     if (virJSONValueObjectAppendNumberInt(object, "pid", sock->pid) < 0)
         goto error;
 
-    if (virJSONValueObjectAppendBoolean(object, "isClient", sock->client) < 0)
+    if (virJSONValueObjectAppendBoolean(object, "unlinkUNIX", 
sock->unlinkUNIX) < 0)
         goto error;
 
     if (virSetInherit(sock->fd, true) < 0) {
@@ -1350,7 +1359,7 @@ void virNetSocketDispose(void *obj)
 
 #ifdef HAVE_SYS_UN_H
     /* If a server socket, then unlink UNIX path */
-    if (!sock->client &&
+    if (sock->unlinkUNIX &&
         sock->localAddr.data.sa.sa_family == AF_UNIX &&
         sock->localAddr.data.un.sun_path[0] != '\0')
         unlink(sock->localAddr.data.un.sun_path);
@@ -2140,8 +2149,8 @@ int virNetSocketAccept(virNetSocketPtr sock, 
virNetSocketPtr *clientsock)
 
     if (!(*clientsock = virNetSocketNew(&localAddr,
                                         &remoteAddr,
-                                        true,
-                                        fd, -1, 0)))
+                                        fd, -1, 0,
+                                        false)))
         goto cleanup;
 
     fd = -1;
@@ -2272,7 +2281,7 @@ void virNetSocketClose(virNetSocketPtr sock)
 
 #ifdef HAVE_SYS_UN_H
     /* If a server socket, then unlink UNIX path */
-    if (!sock->client &&
+    if (sock->unlinkUNIX &&
         sock->localAddr.data.sa.sa_family == AF_UNIX &&
         sock->localAddr.data.un.sun_path[0] != '\0') {
         if (unlink(sock->localAddr.data.un.sun_path) == 0)
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index de5a465cde..2f626cb08f 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -58,6 +58,7 @@ int virNetSocketNewListenUNIX(const char *path,
                               virNetSocketPtr *addr);
 
 int virNetSocketNewListenFD(int fd,
+                            bool unlinkUNIX,
                             virNetSocketPtr *addr);
 
 int virNetSocketNewConnectTCP(const char *nodename,
-- 
2.21.0

--
libvir-list mailing list
libvir-list@redhat.com
https://www.redhat.com/mailman/listinfo/libvir-list

Reply via email to