[libvirt] [PATCH 1/2] SASL: Introduce session mutex

Some of SASL interacting functions can be called within two or more threads with the same pointer. Therefore we need to protect virNetSASLSessionPtr with mutex to avoid non-consistent states. --- src/rpc/virnetsaslcontext.c | 67 +++++++++++++++++++++++++++++++++++++++++- 1 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c index 6b2a883..ef91b9d 100644 --- a/src/rpc/virnetsaslcontext.c +++ b/src/rpc/virnetsaslcontext.c @@ -28,6 +28,7 @@ #include "virterror_internal.h" #include "memory.h" #include "logging.h" +#include "threads.h" #define VIR_FROM_THIS VIR_FROM_RPC #define virNetError(code, ...) \ @@ -41,6 +42,7 @@ struct _virNetSASLContext { }; struct _virNetSASLSession { + virMutex lock; sasl_conn_t *conn; int refs; size_t maxbufsize; @@ -145,6 +147,16 @@ void virNetSASLContextFree(virNetSASLContextPtr ctxt) VIR_FREE(ctxt); } +static void virNetSASLSessionLock(virNetSASLSessionPtr session) +{ + virMutexLock(&session->lock); +} + +static void virNetSASLSessionUnlock(virNetSASLSessionPtr session) +{ + virMutexUnlock(&session->lock); +} + virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED, const char *service, const char *hostname, @@ -160,6 +172,9 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB goto cleanup; } + if (virMutexInit(&sasl->lock) < 0) + goto cleanup; + sasl->refs = 1; /* Arbitrary size for amount of data we can encode in a single block */ sasl->maxbufsize = 1 << 16; @@ -198,6 +213,9 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB goto cleanup; } + if (virMutexInit(&sasl->lock) < 0) + goto cleanup; + sasl->refs = 1; /* Arbitrary size for amount of data we can encode in a single block */ sasl->maxbufsize = 1 << 16; @@ -226,7 +244,9 @@ cleanup: void virNetSASLSessionRef(virNetSASLSessionPtr sasl) { + virNetSASLSessionLock(sasl); sasl->refs++; + virNetSASLSessionUnlock(sasl); } int 
virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl, @@ -234,13 +254,16 @@ int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl, { int err; + virNetSASLSessionLock(sasl); err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf); if (err != SASL_OK) { virNetError(VIR_ERR_INTERNAL_ERROR, _("cannot set external SSF %d (%s)"), err, sasl_errstring(err, NULL, NULL)); + virNetSASLSessionUnlock(sasl); return -1; } + virNetSASLSessionUnlock(sasl); return 0; } @@ -249,13 +272,16 @@ const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl) const void *val; int err; + virNetSASLSessionLock(sasl); err = sasl_getprop(sasl->conn, SASL_USERNAME, &val); if (err != SASL_OK) { virNetError(VIR_ERR_AUTH_FAILED, _("cannot query SASL username on connection %d (%s)"), err, sasl_errstring(err, NULL, NULL)); + virNetSASLSessionUnlock(sasl); return NULL; } + virNetSASLSessionUnlock(sasl); if (val == NULL) { virNetError(VIR_ERR_AUTH_FAILED, _("no client username was found")); @@ -272,13 +298,17 @@ int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl) int err; int ssf; const void *val; + + virNetSASLSessionLock(sasl); err = sasl_getprop(sasl->conn, SASL_SSF, &val); if (err != SASL_OK) { virNetError(VIR_ERR_AUTH_FAILED, _("cannot query SASL ssf on connection %d (%s)"), err, sasl_errstring(err, NULL, NULL)); + virNetSASLSessionUnlock(sasl); return -1; } + virNetSASLSessionUnlock(sasl); ssf = *(const int *)val; return ssf; } @@ -291,6 +321,7 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, sasl_security_properties_t secprops; int err; + virNetSASLSessionLock(sasl); VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu", minSSF, maxSSF, allowAnonymous, sasl->maxbufsize); @@ -307,8 +338,10 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, virNetError(VIR_ERR_INTERNAL_ERROR, _("cannot set security props %d (%s)"), err, sasl_errstring(err, NULL, NULL)); + virNetSASLSessionUnlock(sasl); return -1; } + virNetSASLSessionUnlock(sasl); return 0; } @@ 
-319,17 +352,20 @@ static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl) unsigned *maxbufsize; int err; + virNetSASLSessionLock(sasl); err = sasl_getprop(sasl->conn, SASL_MAXOUTBUF, (const void **)&maxbufsize); if (err != SASL_OK) { virNetError(VIR_ERR_INTERNAL_ERROR, _("cannot get security props %d (%s)"), err, sasl_errstring(err, NULL, NULL)); + virNetSASLSessionUnlock(sasl); return -1; } VIR_DEBUG("Negotiated bufsize is %u vs requested size %zu", *maxbufsize, sasl->maxbufsize); sasl->maxbufsize = *maxbufsize; + virNetSASLSessionUnlock(sasl); return 0; } @@ -339,6 +375,7 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl) char *ret; int err; + virNetSASLSessionLock(sasl); err = sasl_listmech(sasl->conn, NULL, /* Don't need to set user */ "", /* Prefix */ @@ -351,8 +388,10 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl) virNetError(VIR_ERR_INTERNAL_ERROR, _("cannot list SASL mechanisms %d (%s)"), err, sasl_errdetail(sasl->conn)); + virNetSASLSessionUnlock(sasl); return NULL; } + virNetSASLSessionUnlock(sasl); if (!(ret = strdup(mechlist))) { virReportOOMError(); return NULL; @@ -373,6 +412,7 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl, VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p", sasl, mechlist, prompt_need, clientout, clientoutlen, mech); + virNetSASLSessionLock(sasl); int err = sasl_client_start(sasl->conn, mechlist, prompt_need, @@ -380,6 +420,7 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl, &outlen, mech); + virNetSASLSessionUnlock(sasl); *clientoutlen = outlen; switch (err) { @@ -414,12 +455,14 @@ int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl, VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p", sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen); + virNetSASLSessionLock(sasl); int err = sasl_client_step(sasl->conn, serverin, inlen, prompt_need, clientout, &outlen); + 
virNetSASLSessionUnlock(sasl); *clientoutlen = outlen; switch (err) { @@ -449,6 +492,8 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, { unsigned inlen = clientinlen; unsigned outlen = 0; + + virNetSASLSessionLock(sasl); int err = sasl_server_start(sasl->conn, mechname, clientin, @@ -456,6 +501,7 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, serverout, &outlen); + virNetSASLSessionUnlock(sasl); *serveroutlen = outlen; switch (err) { @@ -486,12 +532,14 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl, unsigned inlen = clientinlen; unsigned outlen = 0; + virNetSASLSessionLock(sasl); int err = sasl_server_step(sasl->conn, clientin, inlen, serverout, &outlen); + virNetSASLSessionUnlock(sasl); *serveroutlen = outlen; switch (err) { @@ -514,7 +562,11 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl, size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl) { - return sasl->maxbufsize; + size_t ret; + virNetSASLSessionLock(sasl); + ret = sasl->maxbufsize; + virNetSASLSessionUnlock(sasl); + return ret; } ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, @@ -534,6 +586,7 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, return -1; } + virNetSASLSessionLock(sasl); err = sasl_encode(sasl->conn, input, inlen, @@ -545,8 +598,10 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, virNetError(VIR_ERR_INTERNAL_ERROR, _("failed to encode SASL data: %d (%s)"), err, sasl_errstring(err, NULL, NULL)); + virNetSASLSessionUnlock(sasl); return -1; } + virNetSASLSessionUnlock(sasl); return 0; } @@ -567,6 +622,7 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, return -1; } + virNetSASLSessionLock(sasl); err = sasl_decode(sasl->conn, input, inlen, @@ -577,8 +633,10 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, virNetError(VIR_ERR_INTERNAL_ERROR, _("failed to decode SASL data: %d (%s)"), err, sasl_errstring(err, NULL, NULL)); + virNetSASLSessionUnlock(sasl); return -1; } 
+ virNetSASLSessionUnlock(sasl); return 0; } @@ -587,12 +645,17 @@ void virNetSASLSessionFree(virNetSASLSessionPtr sasl) if (!sasl) return; + virNetSASLSessionLock(sasl); sasl->refs--; - if (sasl->refs > 0) + if (sasl->refs > 0) { + virNetSASLSessionUnlock(sasl); return; + } if (sasl->conn) sasl_dispose(&sasl->conn); + virNetSASLSessionUnlock(sasl); + virMutexDestroy(&sasl->lock); VIR_FREE(sasl); } -- 1.7.5.rc3

Some TLS interacting functions can be called within two or more threads with the same pointer. Therefore we need to protect virNetTLSSessionPtr with mutex to avoid non-consistent states. --- src/rpc/virnettlscontext.c | 41 +++++++++++++++++++++++++++++++++++++++-- 1 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c index bde4e7a..a0f7a3f 100644 --- a/src/rpc/virnettlscontext.c +++ b/src/rpc/virnettlscontext.c @@ -35,6 +35,7 @@ #include "util.h" #include "logging.h" #include "configmake.h" +#include "threads.h" #define DH_BITS 1024 @@ -63,6 +64,7 @@ struct _virNetTLSContext { }; struct _virNetTLSSession { + virMutex lock; int refs; bool handshakeComplete; @@ -1083,6 +1085,16 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt) +static void virNetTLSSessionLock(virNetTLSSessionPtr session) +{ + virMutexLock(&session->lock); +} + +static void virNetTLSSessionUnlock(virNetTLSSessionPtr session) +{ + virMutexUnlock(&session->lock); +} + static ssize_t virNetTLSSessionPush(void *opaque, const void *buf, size_t len) { @@ -1124,6 +1136,9 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, return NULL; } + if (virMutexInit(&sess->lock) < 0) + goto error; + sess->refs = 1; if (hostname && !(sess->hostname = strdup(hostname))) { @@ -1184,7 +1199,9 @@ error: void virNetTLSSessionRef(virNetTLSSessionPtr sess) { + virNetTLSSessionLock(sess); sess->refs++; + virNetTLSSessionUnlock(sess); } void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, @@ -1192,9 +1209,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, virNetTLSSessionReadFunc readFunc, void *opaque) { + virNetTLSSessionLock(sess); sess->writeFunc = writeFunc; sess->readFunc = readFunc; sess->opaque = opaque; + virNetTLSSessionUnlock(sess); } @@ -1202,7 +1221,10 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, const char *buf, size_t len) { ssize_t ret; + + virNetTLSSessionLock(sess); ret = 
gnutls_record_send(sess->session, buf, len); + virNetTLSSessionUnlock(sess); if (ret >= 0) return ret; @@ -1230,7 +1252,9 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, { ssize_t ret; + virNetTLSSessionLock(sess); ret = gnutls_record_recv(sess->session, buf, len); + virNetTLSSessionUnlock(sess); if (ret >= 0) return ret; @@ -1253,15 +1277,19 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) { VIR_DEBUG("sess=%p", sess); + virNetTLSSessionLock(sess); int ret = gnutls_handshake(sess->session); VIR_DEBUG("Ret=%d", ret); if (ret == 0) { sess->handshakeComplete = true; VIR_DEBUG("Handshake is complete"); + virNetTLSSessionUnlock(sess); return 0; } - if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + virNetTLSSessionUnlock(sess); return 1; + } #if 0 PROBE(CLIENT_TLS_FAIL, "fd=%d", @@ -1271,6 +1299,7 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) virNetError(VIR_ERR_AUTH_FAILED, _("TLS handshake failed %s"), gnutls_strerror(ret)); + virNetTLSSessionUnlock(sess); return -1; } @@ -1290,12 +1319,15 @@ int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess) gnutls_cipher_algorithm_t cipher; int ssf; + virNetTLSSessionLock(sess); cipher = gnutls_cipher_get(sess->session); if (!(ssf = gnutls_cipher_get_key_size(cipher))) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("invalid cipher size for TLS session")); + virNetTLSSessionUnlock(sess); return -1; } + virNetTLSSessionUnlock(sess); return ssf; } @@ -1306,11 +1338,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess) if (!sess) return; + virNetTLSSessionLock(sess); sess->refs--; - if (sess->refs > 0) + if (sess->refs > 0) { + virNetTLSSessionUnlock(sess); return; + } VIR_FREE(sess->hostname); gnutls_deinit(sess->session); + virNetTLSSessionUnlock(sess); + virMutexDestroy(&sess->lock); VIR_FREE(sess); } -- 1.7.5.rc3

On 07/25/2011 01:04 PM, Michal Privoznik wrote:
> Some of SASL interacting functions can be called within two or more threads with the same pointer. Therefore we need to protect virNetSASLSessionPtr with mutex to avoid non-consistent states.
Looks like you and danpb had the same idea at roughly the same time. His patch appears more complete (you're missing a mutex on one of the structs), but since I'm severely sleep-deprived at the moment, I'm going to refrain from ACKing either one, as I may miss some problem. Here's Dan's version, which does both SASL and TLS in one go (I'm pointing it out here because I've already ACKed the other patches in that same series, and wanted to make sure this one didn't get lost in the sea of other replies): https://www.redhat.com/archives/libvir-list/2011-July/msg01742.html
participants (2)
-
Laine Stump
-
Michal Privoznik