This extends the basic virNetSocket APIs to allow them to hold
a handle to the TLS/SASL session objects, once established.
This ensures that all data reads/writes are transparently
passed through the TLS/SASL encryption layers when required.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up
  SASL/TLS encryption
---
 src/rpc/virnetsocket.c |  211 +++++++++++++++++++++++++++++++++++++++++++++++-
 src/rpc/virnetsocket.h |    7 ++
 2 files changed, 216 insertions(+), 2 deletions(-)

diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 2bcb8fa..08f4f88 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -55,6 +55,17 @@ struct _virNetSocket {
     virSocketAddr remoteAddr;
     char *localAddrStr;
     char *remoteAddrStr;
+
+    virNetTLSSessionPtr tlsSession;
+    virNetSASLSessionPtr saslSession;
+
+    const char *saslDecoded;
+    size_t saslDecodedLength;
+    size_t saslDecodedOffset;
+
+    const char *saslEncoded;
+    size_t saslEncodedLength;
+    size_t saslEncodedOffset;
 };
 
 
@@ -564,6 +575,12 @@ void virNetSocketFree(virNetSocketPtr sock)
         sock->localAddr.data.un.sun_path[0] != '\0')
         unlink(sock->localAddr.data.un.sun_path);
 
+    /* Make sure it can't send any more I/O during shutdown */
+    if (sock->tlsSession)
+        virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+    virNetTLSSessionFree(sock->tlsSession);
+    virNetSASLSessionFree(sock->saslSession);
+
     VIR_FORCE_CLOSE(sock->fd);
     VIR_FORCE_CLOSE(sock->errfd);
 
@@ -609,14 +626,204 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
     return sock->remoteAddrStr;
 }
 
-ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+/* Write callback for the TLS session: sends raw bytes on sock->fd */
+static ssize_t virNetSocketTLSSessionWrite(const char *buf,
+                                           size_t len,
+                                           void *opaque)
 {
+    virNetSocketPtr sock = opaque;
+    return write(sock->fd, buf, len);
+}
+
+/* Read callback for the TLS session: reads raw bytes from sock->fd */
+static ssize_t virNetSocketTLSSessionRead(char *buf,
+                                          size_t len,
+                                          void *opaque)
+{
+    virNetSocketPtr sock = opaque;
     return read(sock->fd, buf, len);
 }
 
+/* Route @sock I/O through @sess once its TLS handshake is complete */
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess)
+{
+    virNetTLSSessionRef(sess); /* ref the new session before freeing the old */
+    if (sock->tlsSession)
+        virNetTLSSessionFree(sock->tlsSession);
+    sock->tlsSession = sess;
+    virNetTLSSessionSetIOCallbacks(sess,
+                                   virNetSocketTLSSessionWrite,
+                                   virNetSocketTLSSessionRead,
+                                   sock);
+}
+
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess)
+{
+    virNetSASLSessionRef(sess); /* ref the new session before freeing the old */
+    if (sock->saslSession)
+        virNetSASLSessionFree(sock->saslSession);
+    sock->saslSession = sess;
+}
+
+static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
+{
+    ssize_t ret;
+reread:
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
+    } else {
+        ret = read(sock->fd, buf, len);
+    }
+
+    if (ret < 0) {
+        if (errno == EINTR)
+            goto reread;
+        if (errno == EAGAIN)
+            return 0; /* would block: report no data, not an error */
+
+        virReportSystemError(errno, "%s",
+                             _("Cannot recv data"));
+        return -1;
+    }
+    if (ret == 0) {
+        virReportSystemError(EIO, "%s",
+                             _("End of file while reading data"));
+        return -1;
+    }
+
+    return ret;
+}
+
+static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    ssize_t ret;
+rewrite:
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionWrite(sock->tlsSession, buf, len);
+    } else {
+        ret = write(sock->fd, buf, len);
+    }
+
+    if (ret < 0) {
+        if (errno == EINTR)
+            goto rewrite;
+        if (errno == EAGAIN)
+            return 0; /* would block: report nothing written, not an error */
+
+        virReportSystemError(errno, "%s",
+                             _("Cannot write data"));
+        return -1;
+    }
+    if (ret == 0) {
+        virReportSystemError(EIO, "%s",
+                             _("End of file while writing data"));
+        return -1;
+    }
+
+    return ret;
+}
+
+static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
+{
+    size_t got;
+
+    /* No decoded data buffered: read more off the wire and decode it */
+    if (sock->saslDecoded == NULL) {
+        char encoded[8192];
+        ssize_t encodedLen = sizeof(encoded);
+        encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
+
+        if (encodedLen <= 0)
+            return encodedLen; /* -1 error, 0 == EAGAIN */
+
+        if (virNetSASLSessionDecode(sock->saslSession,
+                                    encoded, encodedLen,
+                                    &sock->saslDecoded, &sock->saslDecodedLength) < 0)
+            return -1;
+
+        sock->saslDecodedOffset = 0;
+    }
+
+    /* Hand back as much buffered decoded data as fits in @buf */
+    got = sock->saslDecodedLength - sock->saslDecodedOffset;
+
+    if (len > got)
+        len = got;
+
+    memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
+    sock->saslDecodedOffset += len;
+
+    if (sock->saslDecodedOffset == sock->saslDecodedLength) {
+        sock->saslDecoded = NULL;
+        sock->saslDecodedOffset = sock->saslDecodedLength = 0;
+    }
+
+    return len;
+}
+
+static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    ssize_t ret;
+
+    /* No encoded data pending from an earlier call: encode the caller's buffer */
+    if (sock->saslEncoded == NULL) {
+        if (virNetSASLSessionEncode(sock->saslSession,
+                                    buf, len,
+                                    &sock->saslEncoded,
+                                    &sock->saslEncodedLength) < 0)
+            return -1;
+
+        sock->saslEncodedOffset = 0;
+    }
+
+    /* Send as much of the pending encoded data as the wire accepts */
+    ret = virNetSocketWriteWire(sock,
+                                sock->saslEncoded + sock->saslEncodedOffset,
+                                sock->saslEncodedLength - sock->saslEncodedOffset);
+
+    if (ret <= 0)
+        return ret; /* -1 error, 0 == EAGAIN */
+
+    /* Note how much we sent */
+    sock->saslEncodedOffset += ret;
+
+    /* All encoded data sent: forget the buffer and report @len as written */
+    if (sock->saslEncodedOffset == sock->saslEncodedLength) {
+        sock->saslEncoded = NULL;
+        sock->saslEncodedOffset = sock->saslEncodedLength = 0;
+
+        /* Mark as complete, so caller detects completion */
+        return len;
+    } else {
+        /* Still have stuff pending in saslEncoded buffer.
+         * Pretend to caller that we didn't send any yet.
+         * The caller is expected to retry with the same buffer
+         * shortly, which lets us finish saslEncoded.
+         */
+        return 0;
+    }
+}
+
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+    if (sock->saslSession) /* SASL decoding layered over TLS/raw reads */
+        return virNetSocketReadSASL(sock, buf, len);
+    else
+        return virNetSocketReadWire(sock, buf, len);
+}
+
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
 {
-    return write(sock->fd, buf, len);
+    if (sock->saslSession) /* SASL encoding layered over TLS/raw writes */
+        return virNetSocketWriteSASL(sock, buf, len);
+    else
+        return virNetSocketWriteWire(sock, buf, len);
 }
 
 
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 4441848..94a5f30 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -26,6 +26,8 @@
 
 # include "network.h"
 # include "command.h"
+# include "virnettlscontext.h"
+# include "virnetsaslcontext.h"
 
 typedef struct _virNetSocket virNetSocket;
 typedef virNetSocket *virNetSocketPtr;
@@ -76,6 +78,11 @@ bool virNetSocketIsLocal(virNetSocketPtr sock);
 ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
 
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess);
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess);
+
 void virNetSocketFree(virNetSocketPtr sock);
 
 const char *virNetSocketLocalAddrString(virNetSocketPtr sock);
-- 
1.7.2.3

--
libvir-list mailing list
libvir-list@redhat.com
https://www.redhat.com/mailman/listinfo/libvir-list

Reply via email to