This extends the basic virNetSocket APIs so that a socket can hold
a handle to its TLS/SASL session objects once they are established.
All data subsequently read from or written to the socket is then
automatically passed through the TLS/SASL encryption layers as required.
* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up
SASL/TLS encryption
---
src/rpc/virnetsocket.c | 211 +++++++++++++++++++++++++++++++++++++++++++++++-
src/rpc/virnetsocket.h | 7 ++
2 files changed, 216 insertions(+), 2 deletions(-)
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 2bcb8fa..08f4f88 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -55,6 +55,17 @@ struct _virNetSocket {
virSocketAddr remoteAddr;
char *localAddrStr;
char *remoteAddrStr;
+
+ virNetTLSSessionPtr tlsSession;
+ virNetSASLSessionPtr saslSession;
+
+ const char *saslDecoded;
+ size_t saslDecodedLength;
+ size_t saslDecodedOffset;
+
+ const char *saslEncoded;
+ size_t saslEncodedLength;
+ size_t saslEncodedOffset;
};
@@ -564,6 +575,12 @@ void virNetSocketFree(virNetSocketPtr sock)
sock->localAddr.data.un.sun_path[0] != '\0')
unlink(sock->localAddr.data.un.sun_path);
+ /* Make sure it can't send any more I/O during shutdown */
+ if (sock->tlsSession)
+ virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+ virNetTLSSessionFree(sock->tlsSession);
+ virNetSASLSessionFree(sock->saslSession);
+
VIR_FORCE_CLOSE(sock->fd);
VIR_FORCE_CLOSE(sock->errfd);
@@ -609,14 +626,204 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
return sock->remoteAddrStr;
}
-ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+
+/*
+ * I/O write callback registered on the TLS session by
+ * virNetSocketSetTLSSession: writes the bytes handed over by the TLS
+ * layer (presumably already-encrypted records) straight to the socket's
+ * file descriptor. @opaque is the owning virNetSocketPtr.
+ *
+ * Returns write(2)'s result unchanged; errno (including EINTR/EAGAIN)
+ * is left for the caller to interpret.
+ */
+static ssize_t virNetSocketTLSSessionWrite(const char *buf,
+ size_t len,
+ void *opaque)
{
+ virNetSocketPtr sock = opaque;
+ return write(sock->fd, buf, len);
+}
+
+
+/*
+ * I/O read callback registered on the TLS session by
+ * virNetSocketSetTLSSession: reads raw bytes from the socket's file
+ * descriptor for the TLS layer to process. @opaque is the owning
+ * virNetSocketPtr.
+ *
+ * Returns read(2)'s result unchanged; errno (including EINTR/EAGAIN)
+ * is left for the caller to interpret.
+ */
+static ssize_t virNetSocketTLSSessionRead(char *buf,
+ size_t len,
+ void *opaque)
+{
+ virNetSocketPtr sock = opaque;
return read(sock->fd, buf, len);
}
+
+/*
+ * virNetSocketSetTLSSession:
+ * @sock: the socket
+ * @sess: the TLS session to attach (may be NULL to detach)
+ *
+ * Attach @sess to @sock and register this socket's fd read/write
+ * functions as the session's I/O callbacks, so that once the TLS
+ * handshake is complete all socket I/O is passed through the TLS layer.
+ * The socket takes its own reference on @sess. A previously attached,
+ * different session has its I/O callbacks cleared and the socket's
+ * reference to it released.
+ */
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess)
+{
+    virNetTLSSessionPtr old = sock->tlsSession;
+
+    /* Take the new reference before dropping the old one: if @sess is
+     * the session that is already attached, releasing it first could
+     * free it while we still use it. */
+    if (sess) {
+        virNetTLSSessionRef(sess);
+        virNetTLSSessionSetIOCallbacks(sess,
+                                       virNetSocketTLSSessionWrite,
+                                       virNetSocketTLSSessionRead,
+                                       sock);
+    }
+    sock->tlsSession = sess;
+
+    if (old && old != sess) {
+        /* Anyone else still holding the old session must not be able
+         * to do I/O on this socket (mirrors virNetSocketFree). */
+        virNetTLSSessionSetIOCallbacks(old, NULL, NULL, NULL);
+    }
+    virNetTLSSessionFree(old);
+}
+
+/*
+ * virNetSocketSetSASLSession:
+ * @sock: the socket
+ * @sess: the SASL session to attach (may be NULL to detach)
+ *
+ * Attach @sess to @sock so that virNetSocketRead/virNetSocketWrite
+ * decode/encode all data through it. The socket takes its own reference
+ * on @sess and releases its reference to any previously attached session.
+ *
+ * NOTE(review): sock->saslDecoded/saslEncoded may still point at buffers
+ * produced by the previous session; confirm callers only replace the
+ * session when nothing is pending.
+ */
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess)
+{
+    virNetSASLSessionPtr old = sock->saslSession;
+
+    /* Reference the new session first so that re-setting the same
+     * session cannot drop its last reference and free it. */
+    if (sess)
+        virNetSASLSessionRef(sess);
+    sock->saslSession = sess;
+
+    /* NULL-safe, as relied upon by virNetSocketFree */
+    virNetSASLSessionFree(old);
+}
+
+/*
+ * virNetSocketReadWire:
+ * Read up to @len bytes from the wire into @buf. Once a TLS session is
+ * attached and its handshake has completed, the data is read through
+ * that session; otherwise it is read directly from the socket fd.
+ * Interrupted reads (EINTR) are retried.
+ *
+ * Returns the number of bytes read, 0 if the read would block (EAGAIN),
+ * or -1 with an error reported (end of file is reported as an error).
+ */
+static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
+{
+    for (;;) {
+        ssize_t nread;
+
+        if (sock->tlsSession &&
+            virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+            VIR_NET_TLS_HANDSHAKE_COMPLETE)
+            nread = virNetTLSSessionRead(sock->tlsSession, buf, len);
+        else
+            nread = read(sock->fd, buf, len);
+
+        if (nread > 0)
+            return nread;
+
+        if (nread == 0) {
+            virReportSystemError(EIO, "%s",
+                                 _("End of file while reading data"));
+            return -1;
+        }
+
+        if (errno == EINTR)
+            continue; /* interrupted before any data arrived: retry */
+        if (errno == EAGAIN)
+            return 0;
+
+        virReportSystemError(errno, "%s",
+                             _("Cannot recv data"));
+        return -1;
+    }
+}
+
+/*
+ * virNetSocketWriteWire:
+ * Push up to @len bytes from @buf onto the wire. Once a TLS session is
+ * attached and its handshake has completed, the data is written through
+ * that session; otherwise it goes directly to the socket fd.
+ * Interrupted writes (EINTR) are retried.
+ *
+ * Returns the number of bytes accepted, 0 if the write would block
+ * (EAGAIN), or -1 with an error reported (a zero-byte result is also
+ * reported as an error).
+ */
+static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    for (;;) {
+        ssize_t written;
+
+        if (sock->tlsSession &&
+            virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+            VIR_NET_TLS_HANDSHAKE_COMPLETE)
+            written = virNetTLSSessionWrite(sock->tlsSession, buf, len);
+        else
+            written = write(sock->fd, buf, len);
+
+        if (written > 0)
+            return written;
+
+        if (written == 0) {
+            virReportSystemError(EIO, "%s",
+                                 _("End of file while writing data"));
+            return -1;
+        }
+
+        if (errno == EINTR)
+            continue; /* interrupted before any data moved: retry */
+        if (errno == EAGAIN)
+            return 0;
+
+        virReportSystemError(errno, "%s",
+                             _("Cannot write data"));
+        return -1;
+    }
+}
+
+/*
+ * virNetSocketReadSASL:
+ * @sock: socket with an attached SASL session
+ * @buf: destination for decoded data
+ * @len: capacity of @buf
+ *
+ * Return decoded data from the SASL layer. If no previously decoded data
+ * is buffered, one chunk of encoded data is read off the wire (through
+ * TLS if active) and decoded. Decoded data that does not fit in @buf
+ * stays referenced from sock->saslDecoded and is handed out on later
+ * calls.
+ *
+ * Returns the number of bytes copied into @buf, 0 if nothing is available
+ * yet (the wire read would block, or decoding produced no output), or -1
+ * on error.
+ */
+static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
+{
+    size_t avail;
+
+    /* Need to read some more data off the wire */
+    if (sock->saslDecoded == NULL) {
+        char encoded[8192];
+        ssize_t encodedLen;
+
+        encodedLen = virNetSocketReadWire(sock, encoded, sizeof(encoded));
+        if (encodedLen <= 0)
+            return encodedLen;
+
+        if (virNetSASLSessionDecode(sock->saslSession,
+                                    encoded, encodedLen,
+                                    &sock->saslDecoded,
+                                    &sock->saslDecodedLength) < 0)
+            return -1;
+
+        sock->saslDecodedOffset = 0;
+
+        /* Decoding can succeed without producing any output (e.g. Cyrus
+         * sasl_decode() may hold back an incomplete packet). Treat that
+         * like EAGAIN rather than doing pointer arithmetic and memcpy()
+         * on a possibly NULL output buffer. */
+        if (sock->saslDecodedLength == 0) {
+            sock->saslDecoded = NULL;
+            return 0;
+        }
+    }
+
+    /* Hand out as much buffered decoded data as fits in the caller's buffer */
+    avail = sock->saslDecodedLength - sock->saslDecodedOffset;
+    if (len > avail)
+        len = avail;
+
+    memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
+    sock->saslDecodedOffset += len;
+
+    /* Everything consumed: forget the decoded buffer (it is never freed
+     * here; presumably it is owned by the SASL session) */
+    if (sock->saslDecodedOffset == sock->saslDecodedLength) {
+        sock->saslDecoded = NULL;
+        sock->saslDecodedOffset = sock->saslDecodedLength = 0;
+    }
+
+    return len;
+}
+
+/*
+ * virNetSocketWriteSASL:
+ * @sock: socket with an attached SASL session
+ * @buf: raw (unencoded) data to send
+ * @len: number of bytes in @buf
+ *
+ * Encode @buf through the SASL session and send the result on the wire.
+ * The encoded output is kept in sock->saslEncoded until all of it has
+ * been written. While encoded data is pending, new data is not encoded,
+ * so the caller must retry with the SAME @buf/@len until this returns
+ * @len.
+ *
+ * Returns @len once all the encoded data has been sent, 0 if it could
+ * not (yet) be fully sent, or -1 on error.
+ */
+static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    ssize_t ret;  /* matches virNetSocketWriteWire's return type */
+
+    /* Not got any pending encoded data, so we need to encode raw stuff */
+    if (sock->saslEncoded == NULL) {
+        if (virNetSASLSessionEncode(sock->saslSession,
+                                    buf, len,
+                                    &sock->saslEncoded,
+                                    &sock->saslEncodedLength) < 0)
+            return -1;
+
+        sock->saslEncodedOffset = 0;
+    }
+
+    /* Send some of the encoded stuff out on the wire */
+    ret = virNetSocketWriteWire(sock,
+                                sock->saslEncoded + sock->saslEncodedOffset,
+                                sock->saslEncodedLength -
+                                sock->saslEncodedOffset);
+
+    if (ret <= 0)
+        return ret; /* -1 == error, 0 == EAGAIN */
+
+    /* Note how much we sent */
+    sock->saslEncodedOffset += ret;
+
+    if (sock->saslEncodedOffset == sock->saslEncodedLength) {
+        /* Sent all encoded data, so reset and mark the raw buffer as
+         * complete, so the caller detects completion */
+        sock->saslEncoded = NULL;
+        sock->saslEncodedOffset = sock->saslEncodedLength = 0;
+        return len;
+    }
+
+    /* Still have stuff pending in the saslEncoded buffer. Pretend to the
+     * caller that we didn't send any yet; it will retry with the same
+     * buffer shortly, which lets us finish sending saslEncoded. */
+    return 0;
+}
+
+/*
+ * virNetSocketRead:
+ * Read up to @len bytes of application data from @sock into @buf.
+ * With a SASL session attached the data is decoded through it; the
+ * wire layer underneath applies TLS itself when active.
+ */
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+    if (!sock->saslSession)
+        return virNetSocketReadWire(sock, buf, len);
+
+    return virNetSocketReadSASL(sock, buf, len);
+}
+
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
{
- return write(sock->fd, buf, len);
+    /* Without SASL, raw data goes straight to the wire layer (which
+     * applies TLS itself when active); with SASL it is encoded first. */
+    if (!sock->saslSession)
+        return virNetSocketWriteWire(sock, buf, len);
+
+    return virNetSocketWriteSASL(sock, buf, len);
}
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 4441848..94a5f30 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -26,6 +26,8 @@
# include "network.h"
# include "command.h"
+# include "virnettlscontext.h"
+# include "virnetsaslcontext.h"
typedef struct _virNetSocket virNetSocket;
typedef virNetSocket *virNetSocketPtr;
@@ -76,6 +78,11 @@ bool virNetSocketIsLocal(virNetSocketPtr sock);
ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
+/* Attach a TLS session to @sock. The socket takes its own reference on
+ * @sess and releases any previously attached TLS session. Once the TLS
+ * handshake is complete, reads/writes go through the session. */
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+ virNetTLSSessionPtr sess);
+/* Attach a SASL session to @sock. The socket takes its own reference on
+ * @sess and releases any previously attached SASL session. Subsequent
+ * reads/writes are decoded/encoded through it. */
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+ virNetSASLSessionPtr sess);
+
void virNetSocketFree(virNetSocketPtr sock);
const char *virNetSocketLocalAddrString(virNetSocketPtr sock);
--
1.7.2.3