From: "Daniel P. Berrange" <berrange(a)redhat.com>
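Make virNetTLSContext and virNetTLSSession use the virObject
APIs for reference counting, replacing the hand-rolled refs
counters and virNetTLS{Context,Session}{Ref,Free} with
virObjectRef/virObjectUnref.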
Signed-off-by: Daniel P. Berrange <berrange(a)redhat.com>
---
daemon/libvirtd.c | 4 +-
src/libvirt_private.syms | 2 -
src/libvirt_probes.d | 8 +--
src/remote/remote_driver.c | 2 +-
src/rpc/virnetclient.c | 6 +--
src/rpc/virnetserver.c | 3 +-
src/rpc/virnetserverclient.c | 11 ++---
src/rpc/virnetserverservice.c | 10 ++--
src/rpc/virnetsocket.c | 7 ++-
src/rpc/virnettlscontext.c | 110 +++++++++++++++--------------------------
src/rpc/virnettlscontext.h | 10 +---
tests/virnettlscontexttest.c | 10 ++--
12 files changed, 66 insertions(+), 117 deletions(-)
diff --git a/daemon/libvirtd.c b/daemon/libvirtd.c
index 79f37ae..211a4bc 100644
--- a/daemon/libvirtd.c
+++ b/daemon/libvirtd.c
@@ -541,7 +541,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
false,
config->max_client_requests,
ctxt))) {
- virNetTLSContextFree(ctxt);
+ virObjectUnref(ctxt); /* drop our reference to the TLS context before bailing out */
goto error;
}
if (virNetServerAddService(srv, svcTLS,
@@ -549,7 +549,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
!config->listen_tcp ? "_libvirt._tcp"
: NULL) < 0)
goto error;
- virNetTLSContextFree(ctxt);
+ virObjectUnref(ctxt); /* presumably svcTLS now holds its own reference to ctxt. NOTE(review): the "goto error" on virNetServerAddService failure skips this unref -- confirm the error label releases ctxt */
}
}
diff --git a/src/libvirt_private.syms b/src/libvirt_private.syms
index 3551fd0..035658e 100644
--- a/src/libvirt_private.syms
+++ b/src/libvirt_private.syms
@@ -1481,11 +1481,9 @@ virNetSocketWrite;
# virnettlscontext.h
virNetTLSContextCheckCertificate;
-virNetTLSContextFree;
virNetTLSContextNewClient;
virNetTLSContextNewServer;
virNetTLSContextNewServerPath;
-virNetTLSSessionFree;
virNetTLSSessionHandshake;
virNetTLSSessionNew;
virNetTLSSessionSetIOCallbacks;
diff --git a/src/libvirt_probes.d b/src/libvirt_probes.d
index ceb3caa..3b138a9 100644
--- a/src/libvirt_probes.d
+++ b/src/libvirt_probes.d
@@ -61,19 +61,15 @@ provider libvirt {
# file: src/rpc/virnettlscontext.c
# prefix: rpc
- probe rpc_tls_context_new(void *ctxt, int refs, const char *cacert, const char *cacrl,
+ probe rpc_tls_context_new(void *ctxt, const char *cacert, const char *cacrl,
const char *cert, const char *key, int sanityCheckCert, int requireValidCert, int isServer);
- probe rpc_tls_context_ref(void *ctxt, int refs);
- probe rpc_tls_context_free(void *ctxt, int refs);
probe rpc_tls_context_session_allow(void *ctxt, void *sess, const char *dname);
probe rpc_tls_context_session_deny(void *ctxt, void *sess, const char *dname);
probe rpc_tls_context_session_fail(void *ctxt, void *sess);
- probe rpc_tls_session_new(void *sess, void *ctxt, int refs, const char *hostname, int isServer);
- probe rpc_tls_session_ref(void *sess, int refs);
- probe rpc_tls_session_free(void *sess, int refs);
+ probe rpc_tls_session_new(void *sess, void *ctxt, const char *hostname, int isServer);
probe rpc_tls_session_handshake_pass(void *sess);
probe rpc_tls_session_handshake_fail(void *sess);
diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c
index eac50e6..28035de 100644
--- a/src/remote/remote_driver.c
+++ b/src/remote/remote_driver.c
@@ -908,7 +908,7 @@ doRemoteClose (virConnectPtr conn, struct private_data *priv)
(xdrproc_t) xdr_void, (char *) NULL) == -1)
ret = -1;
- virNetTLSContextFree(priv->tls);
+ virObjectUnref(priv->tls); /* release the connection's TLS context reference; the pointer is cleared just below */
priv->tls = NULL;
virNetClientClose(priv->client);
virNetClientFree(priv->client);
diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 49d238e..2b51246 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -475,7 +475,7 @@ void virNetClientFree(virNetClientPtr client)
if (client->sock)
virNetSocketRemoveIOCallback(client->sock);
virNetSocketFree(client->sock);
- virNetTLSSessionFree(client->tls);
+ virObjectUnref(client->tls); /* client->tls may already be NULL (virNetClientCloseLocked clears it); the old virNetTLSSessionFree ignored NULL, so this relies on virObjectUnref doing the same -- confirm */
#if HAVE_SASL
virNetSASLSessionFree(client->sasl);
#endif
@@ -499,7 +499,7 @@ virNetClientCloseLocked(virNetClientPtr client)
virNetSocketRemoveIOCallback(client->sock);
virNetSocketFree(client->sock);
client->sock = NULL;
- virNetTLSSessionFree(client->tls);
+ virObjectUnref(client->tls);
client->tls = NULL;
#if HAVE_SASL
virNetSASLSessionFree(client->sasl);
@@ -661,7 +661,7 @@ int virNetClientSetTLSSession(virNetClientPtr client,
return 0;
error:
- virNetTLSSessionFree(client->tls);
+ virObjectUnref(client->tls);
client->tls = NULL;
virNetClientUnlock(client);
return -1;
diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c
index 4a02aab..17da40c 100644
--- a/src/rpc/virnetserver.c
+++ b/src/rpc/virnetserver.c
@@ -655,8 +655,7 @@ no_memory:
int virNetServerSetTLSContext(virNetServerPtr srv,
virNetTLSContextPtr tls)
{
- srv->tls = tls;
- virNetTLSContextRef(tls);
+ srv->tls = virObjectRef(tls); /* the server takes its own reference; the caller keeps its one. NOTE(review): any previously set srv->tls is overwritten without virObjectUnref -- confirm this is only called once per server */
return 0;
}
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index a56031c..85a457e 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -348,7 +348,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
client->sock = sock;
client->auth = auth;
client->readonly = readonly;
- client->tlsCtxt = tls;
+ client->tlsCtxt = virObjectRef(tls); /* ref taken as soon as the pointer is stored (previously only after virEventAddTimeout succeeded), so it always matches the virObjectUnref(client->tlsCtxt) in virNetServerClientFree; tls may be NULL (the removed code checked), so this relies on virObjectRef accepting NULL -- confirm */
client->nrequests_max = nrequests_max;
client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
@@ -356,9 +356,6 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
if (client->sockTimer < 0)
goto error;
- if (tls)
- virNetTLSContextRef(tls);
-
/* Prepare one for packet receive */
if (!(client->rx = virNetMessageNew(true)))
goto error;
@@ -600,8 +597,8 @@ void virNetServerClientFree(virNetServerClientPtr client)
#endif
if (client->sockTimer > 0)
virEventRemoveTimeout(client->sockTimer);
- virNetTLSSessionFree(client->tls);
- virNetTLSContextFree(client->tlsCtxt);
+ virObjectUnref(client->tls);
+ virObjectUnref(client->tlsCtxt);
virNetSocketFree(client->sock);
virNetServerClientUnlock(client);
virMutexDestroy(&client->lock);
@@ -656,7 +653,7 @@ void virNetServerClientClose(virNetServerClientPtr client)
virNetSocketRemoveIOCallback(client->sock);
if (client->tls) {
- virNetTLSSessionFree(client->tls);
+ virObjectUnref(client->tls);
client->tls = NULL;
}
client->wantClose = true;
diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c
index 28202a4..b4689b4 100644
--- a/src/rpc/virnetserverservice.c
+++ b/src/rpc/virnetserverservice.c
@@ -116,9 +116,7 @@ virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
svc->auth = auth;
svc->readonly = readonly;
svc->nrequests_client_max = nrequests_client_max;
- svc->tls = tls;
- if (tls)
- virNetTLSContextRef(tls);
+ svc->tls = virObjectRef(tls); /* tls may be NULL (the removed code checked for it); relies on virObjectRef accepting NULL -- confirm */
if (virNetSocketNewListenTCP(nodename,
service,
@@ -172,9 +170,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
svc->auth = auth;
svc->readonly = readonly;
svc->nrequests_client_max = nrequests_client_max;
- svc->tls = tls;
- if (tls)
- virNetTLSContextRef(tls);
+ svc->tls = virObjectRef(tls); /* tls may be NULL (the removed code checked for it); relies on virObjectRef accepting NULL -- confirm */
svc->nsocks = 1;
if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0)
@@ -265,7 +261,7 @@ void virNetServerServiceFree(virNetServerServicePtr svc)
virNetSocketFree(svc->socks[i]);
VIR_FREE(svc->socks);
- virNetTLSContextFree(svc->tls);
+ virObjectUnref(svc->tls);
VIR_FREE(svc);
}
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 0b32ffe..a851dad 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -748,7 +748,7 @@ void virNetSocketFree(virNetSocketPtr sock)
/* Make sure it can't send any more I/O during shutdown */
if (sock->tlsSession)
virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
- virNetTLSSessionFree(sock->tlsSession);
+ virObjectUnref(sock->tlsSession);
#if HAVE_SASL
virNetSASLSessionFree(sock->saslSession);
#endif
@@ -909,13 +909,13 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess)
{
virMutexLock(&sock->lock);
- virNetTLSSessionFree(sock->tlsSession);
+ virObjectRef(sess); /* take the new reference before dropping the old one, so passing the already-installed session cannot dispose it */
+ virObjectUnref(sock->tlsSession);
sock->tlsSession = sess;
virNetTLSSessionSetIOCallbacks(sess,
virNetSocketTLSSessionWrite,
virNetSocketTLSSessionRead,
sock);
- virNetTLSSessionRef(sess);
virMutexUnlock(&sock->lock);
}
diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
index bf92088..74e13c7 100644
--- a/src/rpc/virnettlscontext.c
+++ b/src/rpc/virnettlscontext.c
@@ -53,8 +53,9 @@
__FUNCTION__, __LINE__, __VA_ARGS__)
struct _virNetTLSContext {
+ virObject object;
+
virMutex lock;
- int refs;
gnutls_certificate_credentials_t x509cred;
gnutls_dh_params_t dhParams;
@@ -65,9 +66,9 @@ struct _virNetTLSContext {
};
struct _virNetTLSSession {
- virMutex lock;
+ virObject object;
- int refs;
+ virMutex lock;
bool handshakeComplete;
@@ -79,6 +80,29 @@ struct _virNetTLSSession {
void *opaque;
};
+static virClassPtr virNetTLSContextClass;
+static virClassPtr virNetTLSSessionClass;
+static void virNetTLSContextDispose(void *obj);
+static void virNetTLSSessionDispose(void *obj);
+
+
+static int virNetTLSContextOnceInit(void)
+{
+ if (!(virNetTLSContextClass = virClassNew("virNetTLSContext",
+ sizeof(virNetTLSContext),
+ virNetTLSContextDispose)))
+ return -1;
+
+ if (!(virNetTLSSessionClass = virClassNew("virNetTLSSession",
+ sizeof(virNetTLSSession),
+ virNetTLSSessionDispose)))
+ return -1;
+
+ return 0;
+}
+
+VIR_ONCE_GLOBAL_INIT(virNetTLSContext)
+
static int
virNetTLSContextCheckCertFile(const char *type, const char *file, bool allowMissing)
@@ -650,10 +674,11 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
char *gnutlsdebug;
int err;
- if (VIR_ALLOC(ctxt) < 0) {
- virReportOOMError();
+ if (virNetTLSContextInitialize() < 0) /* one-time registration of both classes via virNetTLSContextOnceInit */
+ return NULL;
+
+ if (!(ctxt = virObjectNew(virNetTLSContextClass)))
return NULL;
- }
if (virMutexInit(&ctxt->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -662,8 +687,6 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
return NULL;
}
- ctxt->refs = 1;
-
if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) {
int val;
if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0)
@@ -719,8 +742,8 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
ctxt->isServer = isServer;
PROBE(RPC_TLS_CONTEXT_NEW,
- "ctxt=%p refs=%d cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
- ctxt, ctxt->refs, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
+ "ctxt=%p cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
+ ctxt, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
return ctxt;
@@ -930,17 +953,6 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
}
-void virNetTLSContextRef(virNetTLSContextPtr ctxt)
-{
- virMutexLock(&ctxt->lock);
- ctxt->refs++;
- PROBE(RPC_TLS_CONTEXT_REF,
- "ctxt=%p refs=%d",
- ctxt, ctxt->refs);
- virMutexUnlock(&ctxt->lock);
-}
-
-
static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt,
virNetTLSSessionPtr sess)
{
@@ -1109,30 +1121,16 @@ cleanup:
return ret;
}
-void virNetTLSContextFree(virNetTLSContextPtr ctxt)
+static void virNetTLSContextDispose(void *obj)
{
- if (!ctxt)
- return;
-
- virMutexLock(&ctxt->lock);
- PROBE(RPC_TLS_CONTEXT_FREE,
- "ctxt=%p refs=%d",
- ctxt, ctxt->refs);
- ctxt->refs--;
- if (ctxt->refs > 0) {
- virMutexUnlock(&ctxt->lock);
- return;
- }
+ virNetTLSContextPtr ctxt = obj; /* dispose callback registered with virClassNew(): releases GnuTLS state and the mutex only; the struct is no longer VIR_FREE'd here */
gnutls_dh_params_deinit(ctxt->dhParams);
gnutls_certificate_free_credentials(ctxt->x509cred);
- virMutexUnlock(&ctxt->lock);
virMutexDestroy(&ctxt->lock);
- VIR_FREE(ctxt);
}
-
static ssize_t
virNetTLSSessionPush(void *opaque, const void *buf, size_t len)
{
@@ -1170,10 +1168,8 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
VIR_DEBUG("ctxt=%p hostname=%s isServer=%d",
ctxt, NULLSTR(hostname), ctxt->isServer);
- if (VIR_ALLOC(sess) < 0) {
- virReportOOMError();
+ if (!(sess = virObjectNew(virNetTLSSessionClass))) /* class presumably already registered when ctxt was built by virNetTLSContextNew */
return NULL;
- }
if (virMutexInit(&sess->lock) < 0) {
virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -1182,7 +1178,6 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
return NULL;
}
- sess->refs = 1;
if (hostname &&
!(sess->hostname = strdup(hostname))) {
virReportOOMError();
@@ -1233,27 +1228,17 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
sess->isServer = ctxt->isServer;
PROBE(RPC_TLS_SESSION_NEW,
- "sess=%p refs=%d ctxt=%p hostname=%s isServer=%d",
- sess, sess->refs, ctxt, hostname, sess->isServer);
+ "sess=%p ctxt=%p hostname=%s isServer=%d",
+ sess, ctxt, hostname, sess->isServer);
return sess;
error:
- virNetTLSSessionFree(sess);
+ virObjectUnref(sess); /* drops what is presumably the only reference, so virNetTLSSessionDispose tears down the partial session */
return NULL;
}
-void virNetTLSSessionRef(virNetTLSSessionPtr sess)
-{
- virMutexLock(&sess->lock);
- sess->refs++;
- PROBE(RPC_TLS_SESSION_REF,
- "sess=%p refs=%d",
- sess, sess->refs);
- virMutexUnlock(&sess->lock);
-}
-
void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
virNetTLSSessionWriteFunc writeFunc,
virNetTLSSessionReadFunc readFunc,
@@ -1396,26 +1381,13 @@ cleanup:
}
-void virNetTLSSessionFree(virNetTLSSessionPtr sess)
+static void virNetTLSSessionDispose(void *obj)
{
- if (!sess)
- return;
-
- virMutexLock(&sess->lock);
- PROBE(RPC_TLS_SESSION_FREE,
- "sess=%p refs=%d",
- sess, sess->refs);
- sess->refs--;
- if (sess->refs > 0) {
- virMutexUnlock(&sess->lock);
- return;
- }
+ virNetTLSSessionPtr sess = obj; /* dispose callback registered with virClassNew(): releases hostname, GnuTLS session and mutex; the struct is no longer VIR_FREE'd here */
VIR_FREE(sess->hostname);
gnutls_deinit(sess->session);
- virMutexUnlock(&sess->lock);
virMutexDestroy(&sess->lock);
- VIR_FREE(sess);
}
/*
diff --git a/src/rpc/virnettlscontext.h b/src/rpc/virnettlscontext.h
index fdfce6d..4821016 100644
--- a/src/rpc/virnettlscontext.h
+++ b/src/rpc/virnettlscontext.h
@@ -22,6 +22,7 @@
# define __VIR_NET_TLS_CONTEXT_H__
# include "internal.h"
+# include "virobject.h" /* callers now manage references with virObjectRef/virObjectUnref instead of the removed Ref/Free functions */
typedef struct _virNetTLSContext virNetTLSContext;
typedef virNetTLSContext *virNetTLSContextPtr;
@@ -58,13 +59,9 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
bool sanityCheckCert,
bool requireValidCert);
-void virNetTLSContextRef(virNetTLSContextPtr ctxt);
-
int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt,
virNetTLSSessionPtr sess);
-void virNetTLSContextFree(virNetTLSContextPtr ctxt);
-
typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len,
void *opaque);
@@ -79,8 +76,6 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
virNetTLSSessionReadFunc readFunc,
void *opaque);
-void virNetTLSSessionRef(virNetTLSSessionPtr sess);
-
ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
const char *buf, size_t len);
ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
@@ -99,7 +94,4 @@ virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess);
int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess);
-void virNetTLSSessionFree(virNetTLSSessionPtr sess);
-
-
#endif
diff --git a/tests/virnettlscontexttest.c b/tests/virnettlscontexttest.c
index e745487..32e1f77 100644
--- a/tests/virnettlscontexttest.c
+++ b/tests/virnettlscontexttest.c
@@ -496,7 +496,7 @@ static int testTLSContextInit(const void *opaque)
ret = 0;
cleanup:
- virNetTLSContextFree(ctxt);
+ virObjectUnref(ctxt);
gnutls_x509_crt_deinit(data->careq.crt);
gnutls_x509_crt_deinit(data->certreq.crt);
data->careq.crt = data->certreq.crt = NULL;
@@ -710,10 +710,10 @@ static int testTLSSessionInit(const void *opaque)
ret = 0;
cleanup:
- virNetTLSContextFree(serverCtxt);
- virNetTLSContextFree(clientCtxt);
- virNetTLSSessionFree(serverSess);
- virNetTLSSessionFree(clientSess);
+ virObjectUnref(serverCtxt);
+ virObjectUnref(clientCtxt);
+ virObjectUnref(serverSess);
+ virObjectUnref(clientSess);
gnutls_x509_crt_deinit(data->careq.crt);
if (data->othercareq.filename)
gnutls_x509_crt_deinit(data->othercareq.crt);
--
1.7.10.2