This extends the basic virNetSocket APIs so that a socket can hold
a reference to its TLS/SASL session objects once they are established.
This ensures that all data reads/writes are transparently
passed through the TLS/SASL encryption layers when required.
* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up
SASL/TLS encryption
---
src/rpc/virnetsocket.c | 274 +++++++++++++++++++++++++++++++++++++++++++++++-
src/rpc/virnetsocket.h | 11 ++
2 files changed, 282 insertions(+), 3 deletions(-)
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index f7d1095..60ee4e5 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -27,6 +27,9 @@
#include <sys/socket.h>
#include <unistd.h>
#include <sys/wait.h>
+#ifdef HAVE_NETINET_TCP_H
+# include <netinet/tcp.h>
+#endif
#ifdef HAVE_NETINET_TCP_H
# include <netinet/tcp.h>
@@ -59,6 +62,19 @@ struct _virNetSocket {
virSocketAddr remoteAddr;
char *localAddrStr;
char *remoteAddrStr;
+
+ virNetTLSSessionPtr tlsSession;
+#if HAVE_SASL
+ virNetSASLSessionPtr saslSession;
+
+ const char *saslDecoded;
+ size_t saslDecodedLength;
+ size_t saslDecodedOffset;
+
+ const char *saslEncoded;
+ size_t saslEncodedLength;
+ size_t saslEncodedOffset;
+#endif
};
@@ -417,7 +433,7 @@ error:
}
-#if HAVE_SYS_UN_H
+#ifdef HAVE_SYS_UN_H
int virNetSocketNewConnectUNIX(const char *path,
bool spawnDaemon,
const char *binary,
@@ -633,6 +649,14 @@ void virNetSocketFree(virNetSocketPtr sock)
unlink(sock->localAddr.data.un.sun_path);
#endif
+ /* Make sure it can't send any more I/O during shutdown */
+ if (sock->tlsSession)
+ virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+ virNetTLSSessionFree(sock->tlsSession);
+#if HAVE_SASL
+ virNetSASLSessionFree(sock->saslSession);
+#endif
+
VIR_FORCE_CLOSE(sock->fd);
VIR_FORCE_CLOSE(sock->errfd);
@@ -718,14 +742,258 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
return sock->remoteAddrStr;
}
-ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+
+/*
+ * Write callback installed on the TLS session by
+ * virNetSocketSetTLSSession(): bytes produced by the TLS layer are
+ * written straight onto the socket's file descriptor.
+ *
+ * @opaque is the virNetSocketPtr registered alongside this callback.
+ * Returns whatever write(2) returns.
+ */
+static ssize_t virNetSocketTLSSessionWrite(const char *buf,
+ size_t len,
+ void *opaque)
{
+ virNetSocketPtr sock = opaque;
+ return write(sock->fd, buf, len);
+}
+
+
+/*
+ * Read callback installed on the TLS session by
+ * virNetSocketSetTLSSession(): raw bytes for the TLS layer are read
+ * straight from the socket's file descriptor.
+ *
+ * @opaque is the virNetSocketPtr registered alongside this callback.
+ * Returns whatever read(2) returns.
+ */
+static ssize_t virNetSocketTLSSessionRead(char *buf,
+ size_t len,
+ void *opaque)
+{
+ virNetSocketPtr sock = opaque;
return read(sock->fd, buf, len);
}
+
+/*
+ * Attach a TLS session to @sock, replacing any previously attached one.
+ *
+ * The socket takes its own reference on @sess and installs I/O
+ * callbacks so the TLS layer reads from and writes to the socket's
+ * file descriptor. Passing NULL simply drops the current session.
+ */
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess)
+{
+    virNetTLSSessionPtr old = sock->tlsSession;
+
+    /* Take the new reference before dropping the old one: if @sess is
+     * the session already attached, releasing first could free it. */
+    if (sess) {
+        virNetTLSSessionRef(sess);
+        virNetTLSSessionSetIOCallbacks(sess,
+                                       virNetSocketTLSSessionWrite,
+                                       virNetSocketTLSSessionRead,
+                                       sock);
+    }
+    sock->tlsSession = sess;
+
+    /* A replaced session may still be referenced elsewhere; make sure
+     * it can no longer perform I/O on this socket's fd. */
+    if (old && old != sess)
+        virNetTLSSessionSetIOCallbacks(old, NULL, NULL, NULL);
+    virNetTLSSessionFree(old);
+}
+
+
+#if HAVE_SASL
+/*
+ * Attach a SASL session to @sock, replacing any previously attached one.
+ *
+ * The socket takes its own reference on @sess; while a SASL session is
+ * attached, virNetSocketRead/virNetSocketWrite go through its security
+ * layer. Passing NULL simply drops the current session.
+ *
+ * NOTE(review): partially consumed saslDecoded/saslEncoded buffers are
+ * not reset here; callers should only swap sessions when none are pending.
+ */
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess)
+{
+    virNetSASLSessionPtr old = sock->saslSession;
+
+    /* Take the new reference before dropping the old one: if @sess is
+     * the session already attached, releasing first could free it. */
+    if (sess)
+        virNetSASLSessionRef(sess);
+    sock->saslSession = sess;
+    virNetSASLSessionFree(old);
+}
+#endif
+
+
+/*
+ * Report whether @sock holds already-decoded data that has not yet been
+ * handed to a reader. Only the SASL layer buffers data on the socket,
+ * so without SASL support the answer is always false.
+ */
+bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
+{
+    bool cached = false;
+#if HAVE_SASL
+    cached = (sock->saslDecoded != NULL);
+#endif
+    return cached;
+}
+
+
+/*
+ * Read raw bytes from @sock. Once the TLS handshake has completed the
+ * data comes through the TLS session, otherwise straight from the fd.
+ * Interrupted reads are retried.
+ *
+ * Returns the number of bytes read, 0 if the read would block (EAGAIN),
+ * or -1 with an error reported on failure or end of file. On failure or
+ * EOF, any text available on sock->errfd (read with a 1024 limit via
+ * virFileReadLimFD) is appended to the error message.
+ */
+static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
+{
+    char *errout = NULL;
+    ssize_t ret;
+    int savedErrno;
+
+reread:
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
+    } else {
+        ret = read(sock->fd, buf, len);
+    }
+
+    if ((ret < 0) && (errno == EINTR))
+        goto reread;
+    if ((ret < 0) && (errno == EAGAIN))
+        return 0;
+
+    /* Capture errno now: draining errfd below does more I/O, which may
+     * overwrite it before the failure is reported. */
+    savedErrno = errno;
+
+    if (ret <= 0 &&
+        sock->errfd != -1 &&
+        virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
+        errout != NULL) {
+        size_t elen = strlen(errout);
+        /* Drop a single trailing newline so the message reads cleanly */
+        if (elen && errout[elen-1] == '\n')
+            errout[elen-1] = '\0';
+    }
+
+    if (ret < 0) {
+        if (errout)
+            virReportSystemError(savedErrno,
+                                 _("Cannot recv data: %s"), errout);
+        else
+            virReportSystemError(savedErrno, "%s",
+                                 _("Cannot recv data"));
+        ret = -1;
+    } else if (ret == 0) {
+        if (errout)
+            virReportSystemError(EIO,
+                                 _("End of file while reading data: %s"),
+                                 errout);
+        else
+            virReportSystemError(EIO, "%s",
+                                 _("End of file while reading data"));
+        ret = -1;
+    }
+
+    VIR_FREE(errout);
+    return ret;
+}
+
+/*
+ * Write raw bytes to @sock. Once the TLS handshake has completed the
+ * data goes through the TLS session, otherwise straight to the fd.
+ * Interrupted writes are retried.
+ *
+ * Returns the number of bytes written, 0 if the write would block
+ * (EAGAIN), or -1 with an error reported on failure or a zero-length
+ * write.
+ */
+static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    ssize_t nwritten;
+
+    do {
+        if (sock->tlsSession &&
+            virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+            VIR_NET_TLS_HANDSHAKE_COMPLETE)
+            nwritten = virNetTLSSessionWrite(sock->tlsSession, buf, len);
+        else
+            nwritten = write(sock->fd, buf, len);
+    } while (nwritten < 0 && errno == EINTR);
+
+    if (nwritten == 0) {
+        virReportSystemError(EIO, "%s",
+                             _("End of file while writing data"));
+        return -1;
+    }
+
+    if (nwritten < 0) {
+        if (errno == EAGAIN)
+            return 0;
+
+        virReportSystemError(errno, "%s",
+                             _("Cannot write data"));
+        return -1;
+    }
+
+    return nwritten;
+}
+
+
+#if HAVE_SASL
+/*
+ * Read up to @len bytes of decoded application data through SASL.
+ *
+ * If no decoded data is buffered, one chunk of at most
+ * virNetSASLSessionGetMaxBufSize() bytes is read off the wire and
+ * decoded into sock->saslDecoded. The buffered decoded data is then
+ * handed out across as many calls as needed until it is drained.
+ *
+ * Returns the number of bytes copied into @buf, 0 if the wire read
+ * would block or nothing is available, or -1 on error.
+ */
+static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
+{
+    size_t avail;
+
+    /* Need to read some more data off the wire */
+    if (sock->saslDecoded == NULL) {
+        ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+        char *encoded;
+        if (VIR_ALLOC_N(encoded, encodedLen) < 0) {
+            virReportOOMError();
+            return -1;
+        }
+        encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
+
+        /* 0 == would block, -1 == error already reported */
+        if (encodedLen <= 0) {
+            VIR_FREE(encoded);
+            return encodedLen;
+        }
+
+        if (virNetSASLSessionDecode(sock->saslSession,
+                                    encoded, encodedLen,
+                                    &sock->saslDecoded,
+                                    &sock->saslDecodedLength) < 0) {
+            VIR_FREE(encoded);
+            return -1;
+        }
+        VIR_FREE(encoded);
+
+        sock->saslDecodedOffset = 0;
+    }
+
+    /* Hand back as much buffered decoded data as fits in @buf */
+    avail = sock->saslDecodedLength - sock->saslDecodedOffset;
+    if (len > avail)
+        len = avail;
+
+    /* Skip the copy when there is nothing to return: if the decoder
+     * produced no output saslDecoded might be NULL, and pointer
+     * arithmetic or memcpy() on NULL is undefined even for length 0. */
+    if (len > 0) {
+        memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
+        sock->saslDecodedOffset += len;
+    }
+
+    /* Drained: forget the decoded buffer (it is not freed here) so the
+     * next call reads from the wire again */
+    if (sock->saslDecodedOffset == sock->saslDecodedLength) {
+        sock->saslDecoded = NULL;
+        sock->saslDecodedOffset = sock->saslDecodedLength = 0;
+    }
+
+    return len;
+}
+
+
+/*
+ * Write @buf through the SASL security layer.
+ *
+ * At most virNetSASLSessionGetMaxBufSize() raw bytes of @buf are encoded
+ * per call. If the encoded form cannot be written out in one go, the
+ * remainder stays in sock->saslEncoded and 0 is returned. The caller must
+ * then retry with the same @buf/@len until a positive count comes back.
+ *
+ * Returns the number of raw bytes consumed from @buf once their encoded
+ * form has been fully written, 0 if the write is still pending or would
+ * block, or -1 on error.
+ */
+static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    /* ssize_t, not int: must hold virNetSocketWriteWire()'s result intact */
+    ssize_t ret;
+    size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+
+    /* SASL doesn't necessarily let us send the whole
+       buffer at once */
+    if (tosend > len)
+        tosend = len;
+
+    /* Not got any pending encoded data, so we need to encode raw stuff */
+    if (sock->saslEncoded == NULL) {
+        if (virNetSASLSessionEncode(sock->saslSession,
+                                    buf, tosend,
+                                    &sock->saslEncoded,
+                                    &sock->saslEncodedLength) < 0)
+            return -1;
+
+        sock->saslEncodedOffset = 0;
+    }
+
+    /* Send some of the encoded stuff out on the wire */
+    ret = virNetSocketWriteWire(sock,
+                                sock->saslEncoded + sock->saslEncodedOffset,
+                                sock->saslEncodedLength -
+                                sock->saslEncodedOffset);
+
+    if (ret <= 0)
+        return ret; /* -1 error, 0 == EAGAIN */
+
+    /* Note how much we sent */
+    sock->saslEncodedOffset += ret;
+
+    /* Sent all encoded, so update raw buffer to indicate completion */
+    if (sock->saslEncodedOffset == sock->saslEncodedLength) {
+        sock->saslEncoded = NULL;
+        sock->saslEncodedOffset = sock->saslEncodedLength = 0;
+
+        /* Mark as complete, so caller detects completion */
+        return tosend;
+    }
+
+    /* Still have stuff pending in saslEncoded buffer.
+     * Pretend to caller that we didn't send any yet.
+     * The caller will then retry with same buffer
+     * shortly, which lets us finish saslEncoded.
+     */
+    return 0;
+}
+#endif
+
+
+/*
+ * Read up to @len bytes of application data from @sock.
+ *
+ * With a SASL session attached, data is decoded through its security
+ * layer; otherwise it is read from the wire (through TLS once the TLS
+ * handshake has completed).
+ */
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+#if HAVE_SASL
+    if (sock->saslSession != NULL)
+        return virNetSocketReadSASL(sock, buf, len);
+#endif
+    return virNetSocketReadWire(sock, buf, len);
+}
+
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
{
- return write(sock->fd, buf, len);
+    /* With a SASL session attached, data is encoded through its
+     * security layer; otherwise it goes to the wire (through TLS once
+     * the TLS handshake has completed). */
+#if HAVE_SASL
+    if (sock->saslSession != NULL)
+        return virNetSocketWriteSASL(sock, buf, len);
+#endif
+    return virNetSocketWriteWire(sock, buf, len);
}
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 218fe8f..59ff288 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -26,6 +26,10 @@
# include "network.h"
# include "command.h"
+# include "virnettlscontext.h"
+# ifdef HAVE_SASL
+# include "virnetsaslcontext.h"
+# endif
typedef struct _virNetSocket virNetSocket;
typedef virNetSocket *virNetSocketPtr;
@@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock,
ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+ virNetTLSSessionPtr sess);
+# ifdef HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+ virNetSASLSessionPtr sess);
+# endif
+bool virNetSocketHasCachedData(virNetSocketPtr sock);
void virNetSocketFree(virNetSocketPtr sock);
const char *virNetSocketLocalAddrString(virNetSocketPtr sock);
--
1.7.4.4