Some TLS interacting functions can be called from two or more
threads with the same session pointer. Therefore we need to protect
virNetTLSSessionPtr with a mutex to avoid inconsistent states.
---
src/rpc/virnettlscontext.c | 41 +++++++++++++++++++++++++++++++++++++++--
1 files changed, 39 insertions(+), 2 deletions(-)
diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
index bde4e7a..a0f7a3f 100644
--- a/src/rpc/virnettlscontext.c
+++ b/src/rpc/virnettlscontext.c
@@ -35,6 +35,7 @@
#include "util.h"
#include "logging.h"
#include "configmake.h"
+#include "threads.h"
#define DH_BITS 1024
@@ -63,6 +64,7 @@ struct _virNetTLSContext {
};
struct _virNetTLSSession {
+ virMutex lock;
int refs;
bool handshakeComplete;
@@ -1083,6 +1085,16 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt)
+static void virNetTLSSessionLock(virNetTLSSessionPtr session)
+{
+ virMutexLock(&session->lock);
+}
+
+static void virNetTLSSessionUnlock(virNetTLSSessionPtr session)
+{
+ virMutexUnlock(&session->lock);
+}
+
static ssize_t
virNetTLSSessionPush(void *opaque, const void *buf, size_t len)
{
@@ -1124,6 +1136,9 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
return NULL;
}
+ if (virMutexInit(&sess->lock) < 0)
+ goto error;
+
sess->refs = 1;
if (hostname &&
!(sess->hostname = strdup(hostname))) {
@@ -1184,7 +1199,9 @@ error:
void virNetTLSSessionRef(virNetTLSSessionPtr sess)
{
+ virNetTLSSessionLock(sess);
sess->refs++;
+ virNetTLSSessionUnlock(sess);
}
void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
@@ -1192,9 +1209,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
virNetTLSSessionReadFunc readFunc,
void *opaque)
{
+ virNetTLSSessionLock(sess);
sess->writeFunc = writeFunc;
sess->readFunc = readFunc;
sess->opaque = opaque;
+ virNetTLSSessionUnlock(sess);
}
@@ -1202,7 +1221,10 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
const char *buf, size_t len)
{
ssize_t ret;
+
+ virNetTLSSessionLock(sess);
ret = gnutls_record_send(sess->session, buf, len);
+ virNetTLSSessionUnlock(sess);
if (ret >= 0)
return ret;
@@ -1230,7 +1252,9 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
{
ssize_t ret;
+ virNetTLSSessionLock(sess);
ret = gnutls_record_recv(sess->session, buf, len);
+ virNetTLSSessionUnlock(sess);
if (ret >= 0)
return ret;
@@ -1253,15 +1277,19 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
{
VIR_DEBUG("sess=%p", sess);
+ virNetTLSSessionLock(sess);
int ret = gnutls_handshake(sess->session);
VIR_DEBUG("Ret=%d", ret);
if (ret == 0) {
sess->handshakeComplete = true;
VIR_DEBUG("Handshake is complete");
+ virNetTLSSessionUnlock(sess);
return 0;
}
- if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
+ if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
+ virNetTLSSessionUnlock(sess);
return 1;
+ }
#if 0
PROBE(CLIENT_TLS_FAIL, "fd=%d",
@@ -1271,6 +1299,7 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
virNetError(VIR_ERR_AUTH_FAILED,
_("TLS handshake failed %s"),
gnutls_strerror(ret));
+ virNetTLSSessionUnlock(sess);
return -1;
}
@@ -1290,12 +1319,15 @@ int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess)
gnutls_cipher_algorithm_t cipher;
int ssf;
+ virNetTLSSessionLock(sess);
cipher = gnutls_cipher_get(sess->session);
if (!(ssf = gnutls_cipher_get_key_size(cipher))) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
_("invalid cipher size for TLS session"));
+ virNetTLSSessionUnlock(sess);
return -1;
}
+ virNetTLSSessionUnlock(sess);
return ssf;
}
@@ -1306,11 +1338,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess)
if (!sess)
return;
+ virNetTLSSessionLock(sess);
sess->refs--;
- if (sess->refs > 0)
+ if (sess->refs > 0) {
+ virNetTLSSessionUnlock(sess);
return;
+ }
VIR_FREE(sess->hostname);
gnutls_deinit(sess->session);
+ virNetTLSSessionUnlock(sess);
+ virMutexDestroy(&sess->lock);
VIR_FREE(sess);
}
--
1.7.5.rc3