This extends the basic virNetSocket APIs to allow them to have
a handle to the TLS/SASL session objects, once established.
This ensures that any data reads/writes are automagically
passed through the TLS/SASL encryption layers if required.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up
  SASL/TLS encryption
---
 src/rpc/virnetsocket.c |  276 +++++++++++++++++++++++++++++++++++++++++++++++-
 src/rpc/virnetsocket.h |   11 ++
 2 files changed, 284 insertions(+), 3 deletions(-)

diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index a0eb431..a5ee861 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -27,6 +27,9 @@
 #include <sys/socket.h>
 #include <unistd.h>
 #include <sys/wait.h>
+#ifdef HAVE_NETINET_TCP_H
+# include <netinet/tcp.h>
+#endif
 
 #include "virnetsocket.h"
 #include "util.h"
@@ -55,6 +58,19 @@ struct _virNetSocket {
     virSocketAddr remoteAddr;
     char *localAddrStr;
     char *remoteAddrStr;
+
+    virNetTLSSessionPtr tlsSession;
+#if HAVE_SASL
+    virNetSASLSessionPtr saslSession;
+
+    const char *saslDecoded;
+    size_t saslDecodedLength;
+    size_t saslDecodedOffset;
+
+    const char *saslEncoded;
+    size_t saslEncodedLength;
+    size_t saslEncodedOffset;
+#endif
 };
 
 
@@ -394,7 +410,7 @@ error:
 }
 
 
-#if HAVE_SYS_UN_H
+#ifdef HAVE_SYS_UN_H
 int virNetSocketNewConnectUNIX(const char *path,
                                bool spawnDaemon,
                                const char *binary,
@@ -610,6 +626,14 @@ void virNetSocketFree(virNetSocketPtr sock)
         unlink(sock->localAddr.data.un.sun_path);
 #endif
 
+    /* Make sure it can't send any more I/O during shutdown */
+    if (sock->tlsSession)
+        virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+    virNetTLSSessionFree(sock->tlsSession);
+#if HAVE_SASL
+    virNetSASLSessionFree(sock->saslSession);
+#endif
+
     VIR_FORCE_CLOSE(sock->fd);
     VIR_FORCE_CLOSE(sock->errfd);
 
@@ -695,14 +719,260 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
     return sock->remoteAddrStr;
 }
 
-ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+
+static ssize_t virNetSocketTLSSessionWrite(const char *buf,
+                                           size_t len,
+                                           void *opaque)
 {
+    virNetSocketPtr sock = opaque;
+    return write(sock->fd, buf, len);
+}
+
+
+static ssize_t virNetSocketTLSSessionRead(char *buf,
+                                          size_t len,
+                                          void *opaque)
+{
+    virNetSocketPtr sock = opaque;
     return read(sock->fd, buf, len);
 }
 
+
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess)
+{
+    if (sock->tlsSession)
+        virNetTLSSessionFree(sock->tlsSession);
+    sock->tlsSession = sess;
+    virNetTLSSessionSetIOCallbacks(sess,
+                                   virNetSocketTLSSessionWrite,
+                                   virNetSocketTLSSessionRead,
+                                   sock);
+    virNetTLSSessionRef(sess);
+}
+
+
+#if HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess)
+{
+    if (sock->saslSession)
+        virNetSASLSessionFree(sock->saslSession);
+    sock->saslSession = sess;
+    virNetSASLSessionRef(sess);
+}
+#endif
+
+
+bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
+{
+#if HAVE_SASL
+    if (sock->saslDecoded)
+        return true;
+#endif
+    return false;
+}
+
+
+static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
+{
+    char *errout = NULL;
+    ssize_t ret;
+reread:
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
+    } else {
+        ret = read(sock->fd, buf, len);
+    }
+
+    if ((ret < 0) && (errno == EINTR))
+        goto reread;
+    if ((ret < 0) && (errno == EAGAIN))
+        return 0;
+
+    if (ret <= 0 &&
+        sock->errfd != -1 &&
+        virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
+        errout != NULL) {
+        size_t elen = strlen(errout);
+        if (elen && errout[elen-1] == '\n')
+            errout[elen-1] = '\0';
+    }
+
+    if (ret < 0) {
+        if (errout)
+            virReportSystemError(errno,
+                                 _("Cannot recv data: %s"), errout);
+        else
+            virReportSystemError(errno, "%s",
+                                 _("Cannot recv data"));
+        ret = -1;
+    } else if (ret == 0) {
+        if (errout)
+            virReportSystemError(EIO,
+                                 _("End of file while reading data: %s"), errout);
+        else
+            virReportSystemError(EIO, "%s",
+                                 _("End of file while reading data"));
+        ret = -1;
+    }
+
+    VIR_FREE(errout);
+    return ret;
+}
+
+static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    ssize_t ret;
+rewrite:
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionWrite(sock->tlsSession, buf, len);
+    } else {
+        ret = write(sock->fd, buf, len);
+    }
+
+    if (ret < 0) {
+        if (errno == EINTR)
+            goto rewrite;
+        if (errno == EAGAIN)
+            return 0;
+
+        virReportSystemError(errno, "%s",
+                             _("Cannot write data"));
+        return -1;
+    }
+    if (ret == 0) {
+        virReportSystemError(EIO, "%s",
+                             _("End of file while writing data"));
+        return -1;
+    }
+
+    return ret;
+}
+
+
+#if HAVE_SASL
+static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
+{
+    ssize_t got;
+
+    /* Need to read some more data off the wire */
+    if (sock->saslDecoded == NULL) {
+        ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+        char *encoded;
+        if (VIR_ALLOC_N(encoded, encodedLen) < 0) {
+            virReportOOMError();
+            return -1;
+        }
+        encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
+
+        if (encodedLen <= 0) {
+            VIR_FREE(encoded);
+            return encodedLen;
+        }
+
+        if (virNetSASLSessionDecode(sock->saslSession,
+                                    encoded, encodedLen,
+                                    &sock->saslDecoded, &sock->saslDecodedLength) < 0) {
+            VIR_FREE(encoded);
+            return -1;
+        }
+        VIR_FREE(encoded);
+
+        sock->saslDecodedOffset = 0;
+    }
+
+    /* Some buffered decoded data to return now */
+    got = sock->saslDecodedLength - sock->saslDecodedOffset;
+
+    if (len > got)
+        len = got;
+
+    memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
+    sock->saslDecodedOffset += len;
+
+    if (sock->saslDecodedOffset == sock->saslDecodedLength) {
+        sock->saslDecoded = NULL;
+        sock->saslDecodedOffset = sock->saslDecodedLength = 0;
+    }
+
+    return len;
+}
+
+
+static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    int ret;
+    size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+
+    /* SASL doesn't necessarily let us send the whole
+       buffer at once */
+    if (tosend > len)
+        tosend = len;
+
+    /* Not got any pending encoded data, so we need to encode raw stuff */
+    if (sock->saslEncoded == NULL) {
+        if (virNetSASLSessionEncode(sock->saslSession,
+                                    buf, tosend,
+                                    &sock->saslEncoded,
+                                    &sock->saslEncodedLength) < 0)
+            return -1;
+
+        sock->saslEncodedOffset = 0;
+    }
+
+    /* Send some of the encoded stuff out on the wire */
+    ret = virNetSocketWriteWire(sock,
+                                sock->saslEncoded + sock->saslEncodedOffset,
+                                sock->saslEncodedLength - sock->saslEncodedOffset);
+
+    if (ret <= 0)
+        return ret; /* -1 error, 0 == EAGAIN */
+
+    /* Note how much we sent */
+    sock->saslEncodedOffset += ret;
+
+    /* Sent all encoded, so update raw buffer to indicate completion */
+    if (sock->saslEncodedOffset == sock->saslEncodedLength) {
+        sock->saslEncoded = NULL;
+        sock->saslEncodedOffset = sock->saslEncodedLength = 0;
+
+        /* Mark as complete, so caller detects completion */
+        return tosend;
+    } else {
+        /* Still have stuff pending in saslEncoded buffer.
+         * Pretend to caller that we didn't send any yet.
+         * The caller will then retry with same buffer
+         * shortly, which lets us finish saslEncoded.
+         */
+        return 0;
+    }
+}
+#endif
+
+
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+#if HAVE_SASL
+    if (sock->saslSession)
+        return virNetSocketReadSASL(sock, buf, len);
+    else
+#endif
+        return virNetSocketReadWire(sock, buf, len);
+}
+
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
 {
-    return write(sock->fd, buf, len);
+#if HAVE_SASL
+    if (sock->saslSession)
+        return virNetSocketWriteSASL(sock, buf, len);
+    else
+#endif
+        return virNetSocketWriteWire(sock, buf, len);
 }
 
 
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index c33b2e1..1be423b 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -26,6 +26,10 @@
 
 # include "network.h"
 # include "command.h"
+# include "virnettlscontext.h"
+# ifdef HAVE_SASL
+#  include "virnetsaslcontext.h"
+# endif
 
 typedef struct _virNetSocket virNetSocket;
 typedef virNetSocket *virNetSocketPtr;
@@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock,
 ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
 
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess);
+# ifdef HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess);
+# endif
+bool virNetSocketHasCachedData(virNetSocketPtr sock);
 void virNetSocketFree(virNetSocketPtr sock);
 
 const char *virNetSocketLocalAddrString(virNetSocketPtr sock);
-- 
1.7.4

--
libvir-list mailing list
libvir-list@redhat.com
https://www.redhat.com/mailman/listinfo/libvir-list

Reply via email to