This extends the basic virNetSocket APIs so that a socket can hold a handle to its TLS and SASL session objects once they have been established. This ensures that all data reads and writes on the socket are transparently passed through the TLS/SASL encryption layers when required.
* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up SASL/TLS encryption --- src/rpc/virnetsocket.c | 211 +++++++++++++++++++++++++++++++++++++++++++++++- src/rpc/virnetsocket.h | 7 ++ 2 files changed, 216 insertions(+), 2 deletions(-) diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index 2bcb8fa..08f4f88 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -55,6 +55,17 @@ struct _virNetSocket { virSocketAddr remoteAddr; char *localAddrStr; char *remoteAddrStr; + + virNetTLSSessionPtr tlsSession; + virNetSASLSessionPtr saslSession; + + const char *saslDecoded; + size_t saslDecodedLength; + size_t saslDecodedOffset; + + const char *saslEncoded; + size_t saslEncodedLength; + size_t saslEncodedOffset; }; @@ -564,6 +575,12 @@ void virNetSocketFree(virNetSocketPtr sock) sock->localAddr.data.un.sun_path[0] != '\0') unlink(sock->localAddr.data.un.sun_path); + /* Make sure it can't send any more I/O during shutdown */ + if (sock->tlsSession) + virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL); + virNetTLSSessionFree(sock->tlsSession); + virNetSASLSessionFree(sock->saslSession); + VIR_FORCE_CLOSE(sock->fd); VIR_FORCE_CLOSE(sock->errfd); @@ -609,14 +626,204 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) return sock->remoteAddrStr; } -ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) + +static ssize_t virNetSocketTLSSessionWrite(const char *buf, + size_t len, + void *opaque) { + virNetSocketPtr sock = opaque; + return write(sock->fd, buf, len); +} + + +static ssize_t virNetSocketTLSSessionRead(char *buf, + size_t len, + void *opaque) +{ + virNetSocketPtr sock = opaque; return read(sock->fd, buf, len); } + +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess) +{ + if (sock->tlsSession) + virNetTLSSessionFree(sock->tlsSession); + sock->tlsSession = sess; + virNetTLSSessionSetIOCallbacks(sess, + virNetSocketTLSSessionWrite, + virNetSocketTLSSessionRead, + sock); + 
virNetTLSSessionRef(sess); +} + +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess) +{ + if (sock->saslSession) + virNetSASLSessionFree(sock->saslSession); + sock->saslSession = sess; + virNetSASLSessionRef(sess); +} + +static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len) +{ + ssize_t ret; +reread: + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionRead(sock->tlsSession, buf, len); + } else { + ret = read(sock->fd, buf, len); + } + + if (ret < 0) { + if (errno == EINTR) + goto reread; + if (errno == EAGAIN) + return 0; + + virReportSystemError(errno, "%s", + _("Cannot recv data")); + return -1; + } + if (ret == 0) { + virReportSystemError(EIO, "%s", + _("End of file while reading data")); + return -1; + } + + return ret; +} + +static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len) +{ + ssize_t ret; +rewrite: + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionWrite(sock->tlsSession, buf, len); + } else { + ret = write(sock->fd, buf, len); + } + + if (ret < 0) { + if (errno == EINTR) + goto rewrite; + if (errno == EAGAIN) + return 0; + + virReportSystemError(errno, "%s", + _("Cannot write data")); + return -1; + } + if (ret == 0) { + virReportSystemError(EIO, "%s", + _("End of file while writing data")); + return -1; + } + + return ret; +} + +static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len) +{ + ssize_t got; + + /* Need to read some more data off the wire */ + if (sock->saslDecoded == NULL) { + char encoded[8192]; + ssize_t encodedLen = sizeof(encoded); + encodedLen = virNetSocketReadWire(sock, encoded, encodedLen); + + if (encodedLen <= 0) + return encodedLen; + + if (virNetSASLSessionDecode(sock->saslSession, + encoded, encodedLen, + &sock->saslDecoded, 
&sock->saslDecodedLength) < 0) + return -1; + + sock->saslDecodedOffset = 0; + } + + /* Some buffered decoded data to return now */ + got = sock->saslDecodedLength - sock->saslDecodedOffset; + + if (len > got) + len = got; + + memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len); + sock->saslDecodedOffset += len; + + if (sock->saslDecodedOffset == sock->saslDecodedLength) { + sock->saslDecoded = NULL; + sock->saslDecodedOffset = sock->saslDecodedLength = 0; + } + + return len; +} + +static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len) +{ + int ret; + + /* Not got any pending encoded data, so we need to encode raw stuff */ + if (sock->saslEncoded == NULL) { + if (virNetSASLSessionEncode(sock->saslSession, + buf, len, + &sock->saslEncoded, + &sock->saslEncodedLength) < 0) + return -1; + + sock->saslEncodedOffset = 0; + } + + /* Send some of the encoded stuff out on the wire */ + ret = virNetSocketWriteWire(sock, + sock->saslEncoded + sock->saslEncodedOffset, + sock->saslEncodedLength - sock->saslEncodedOffset); + + if (ret <= 0) + return ret; /* -1 error, 0 == egain */ + + /* Note how much we sent */ + sock->saslEncodedOffset += ret; + + /* Sent all encoded, so update raw buffer to indicate completion */ + if (sock->saslEncodedOffset == sock->saslEncodedLength) { + sock->saslEncoded = NULL; + sock->saslEncodedOffset = sock->saslEncodedLength = 0; + + /* Mark as complete, so caller detects completion */ + return len; + } else { + /* Still have stuff pending in saslEncoded buffer. + * Pretend to caller that we didn't send any yet. + * The caller will then retry with same buffer + * shortly, which lets us finish saslEncoded. 
+ */ + return 0; + } +} + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) +{ + if (sock->saslSession) + return virNetSocketReadSASL(sock, buf, len); + else + return virNetSocketReadWire(sock, buf, len); +} + ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) { - return write(sock->fd, buf, len); + if (sock->saslSession) + return virNetSocketWriteSASL(sock, buf, len); + else + return virNetSocketWriteWire(sock, buf, len); } diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index 4441848..94a5f30 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ -26,6 +26,8 @@ # include "network.h" # include "command.h" +# include "virnettlscontext.h" +# include "virnetsaslcontext.h" typedef struct _virNetSocket virNetSocket; typedef virNetSocket *virNetSocketPtr; @@ -76,6 +78,11 @@ bool virNetSocketIsLocal(virNetSocketPtr sock); ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess); +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess); + void virNetSocketFree(virNetSocketPtr sock); const char *virNetSocketLocalAddrString(virNetSocketPtr sock); -- 1.7.2.3 -- libvir-list mailing list libvir-list@redhat.com https://www.redhat.com/mailman/listinfo/libvir-list