From: "Daniel P. Berrange" <berrange(a)redhat.com>
Remove the need for objects that use a virNetSocket to protect it
with their own locks, by giving virNetSocket its own native
locking and reference counting.
* src/rpc/virnetsocket.c: Add locking & reference counting
---
src/rpc/virnetsocket.c | 147 +++++++++++++++++++++++++++++++++++++++---------
1 files changed, 120 insertions(+), 27 deletions(-)
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 7ea1ab7..8dd4d3a 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -40,6 +40,7 @@
#include "logging.h"
#include "files.h"
#include "event.h"
+#include "threads.h"
#define VIR_FROM_THIS VIR_FROM_RPC
@@ -49,6 +50,9 @@
struct _virNetSocket {
+ virMutex lock;
+ int refs;
+
int fd;
int watch;
pid_t pid;
@@ -122,6 +126,14 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
return NULL;
}
+ if (virMutexInit(&sock->lock) < 0) {
+ virReportSystemError(errno, "%s",
+ _("Unable to initialize mutex"));
+ VIR_FREE(sock);
+ return NULL;
+ }
+ sock->refs = 1;
+
if (localAddr)
sock->localAddr = *localAddr;
if (remoteAddr)
@@ -627,6 +639,13 @@ void virNetSocketFree(virNetSocketPtr sock)
if (!sock)
return;
+ virMutexLock(&sock->lock);
+ sock->refs--;
+ if (sock->refs > 0) {
+ virMutexUnlock(&sock->lock);
+ return;
+ }
+
VIR_DEBUG("sock=%p fd=%d", sock, sock->fd);
if (sock->watch > 0) {
virEventRemoveHandle(sock->watch);
@@ -657,27 +676,41 @@ void virNetSocketFree(virNetSocketPtr sock)
VIR_FREE(sock->localAddrStr);
VIR_FREE(sock->remoteAddrStr);
+ virMutexUnlock(&sock->lock);
+ virMutexDestroy(&sock->lock);
+
VIR_FREE(sock);
}
int virNetSocketGetFD(virNetSocketPtr sock)
{
- return sock->fd;
+ int fd; /* snapshot of sock->fd, taken while sock->lock is held */
+ virMutexLock(&sock->lock);
+ fd = sock->fd;
+ virMutexUnlock(&sock->lock);
+ return fd;
}
bool virNetSocketIsLocal(virNetSocketPtr sock)
{
+ bool isLocal = false; /* true only when the local address family is AF_UNIX */
+ virMutexLock(&sock->lock); /* localAddr is read under the socket lock */
if (sock->localAddr.data.sa.sa_family == AF_UNIX)
- return true;
- return false;
+ isLocal = true;
+ virMutexUnlock(&sock->lock);
+ return isLocal;
}
int virNetSocketGetPort(virNetSocketPtr sock)
{
- return virSocketGetPort(&sock->localAddr);
+ int port; /* result of virSocketGetPort() on sock->localAddr */
+ virMutexLock(&sock->lock); /* localAddr is read under the socket lock */
+ port = virSocketGetPort(&sock->localAddr);
+ virMutexUnlock(&sock->lock);
+ return port;
}
@@ -688,15 +721,19 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock,
{
struct ucred cr;
unsigned int cr_len = sizeof (cr);
+ virMutexLock(&sock->lock);
if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) {
virReportSystemError(errno, "%s",
_("Failed to get client socket identity"));
+ virMutexUnlock(&sock->lock);
return -1;
}
*pid = cr.pid;
*uid = cr.uid;
+
+ virMutexUnlock(&sock->lock);
return 0;
}
#else
@@ -715,7 +752,11 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock
ATTRIBUTE_UNUSED,
int virNetSocketSetBlocking(virNetSocketPtr sock,
bool blocking)
{
- return virSetBlocking(sock->fd, blocking);
+ int ret; /* return value of virSetBlocking(), passed through unchanged */
+ virMutexLock(&sock->lock); /* fd is accessed under the socket lock */
+ ret = virSetBlocking(sock->fd, blocking);
+ virMutexUnlock(&sock->lock);
+ return ret;
}
@@ -751,6 +792,7 @@ static ssize_t virNetSocketTLSSessionRead(char *buf,
void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess)
{
+ virMutexLock(&sock->lock);
virNetTLSSessionFree(sock->tlsSession);
sock->tlsSession = sess;
virNetTLSSessionSetIOCallbacks(sess,
@@ -758,6 +800,7 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetSocketTLSSessionRead,
sock);
virNetTLSSessionRef(sess);
+ virMutexUnlock(&sock->lock);
}
@@ -765,20 +808,25 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
void virNetSocketSetSASLSession(virNetSocketPtr sock,
virNetSASLSessionPtr sess)
{
+ virMutexLock(&sock->lock); /* release old session, store and reference the new one under one lock hold */
virNetSASLSessionFree(sock->saslSession);
sock->saslSession = sess;
virNetSASLSessionRef(sess);
+ virMutexUnlock(&sock->lock); /* the socket now owns its own reference on sess */
}
#endif
-bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
+/*
+ * virNetSocketHasCachedData:
+ * @sock: the socket to query
+ *
+ * Returns true if sock->saslDecoded is set (only possible in builds with
+ * SASL support), false otherwise. The check is done under sock->lock.
+ * @sock is now dereferenced unconditionally to take the lock, so it is
+ * no longer marked ATTRIBUTE_UNUSED.
+ */
+bool virNetSocketHasCachedData(virNetSocketPtr sock)
{
+ bool hasCached = false;
+ virMutexLock(&sock->lock);
#if HAVE_SASL
if (sock->saslDecoded)
- return true;
+ hasCached = true;
#endif
- return false;
+ virMutexUnlock(&sock->lock);
+ return hasCached;
}
@@ -965,39 +1013,54 @@ static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const
char *buf, size
ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
{
+ ssize_t ret; /* value returned by the SASL or wire read helper */
+ virMutexLock(&sock->lock); /* NOTE(review): held across the read itself, so Read/Write and all other socket calls serialize and wait if the fd blocks; assumes the Read* helpers never retake sock->lock -- confirm */
#if HAVE_SASL
if (sock->saslSession)
- return virNetSocketReadSASL(sock, buf, len);
+ ret = virNetSocketReadSASL(sock, buf, len);
else
#endif
- return virNetSocketReadWire(sock, buf, len);
+ ret = virNetSocketReadWire(sock, buf, len);
+ virMutexUnlock(&sock->lock);
+ return ret;
}
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
{
+ ssize_t ret; /* value returned by the SASL or wire write helper */
+
+ virMutexLock(&sock->lock); /* NOTE(review): held across the write itself; assumes the Write* helpers never retake sock->lock -- confirm */
#if HAVE_SASL
if (sock->saslSession)
- return virNetSocketWriteSASL(sock, buf, len);
+ ret = virNetSocketWriteSASL(sock, buf, len);
else
#endif
- return virNetSocketWriteWire(sock, buf, len);
+ ret = virNetSocketWriteWire(sock, buf, len);
+ virMutexUnlock(&sock->lock);
+ return ret;
}
int virNetSocketListen(virNetSocketPtr sock)
{
+ virMutexLock(&sock->lock); /* every return path below unlocks */
if (listen(sock->fd, 30) < 0) {
virReportSystemError(errno, "%s", _("Unable to listen on
socket"));
+ virMutexUnlock(&sock->lock);
return -1;
}
+ virMutexUnlock(&sock->lock);
return 0;
}
int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock)
{
- int fd;
+ int fd = -1;
virSocketAddr localAddr;
virSocketAddr remoteAddr;
+ int ret = -1;
+
+ virMutexLock(&sock->lock);
*clientsock = NULL;
@@ -1007,30 +1070,35 @@ int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr
*clientsock)
remoteAddr.len = sizeof(remoteAddr.data.stor);
if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0)
{
if (errno == ECONNABORTED ||
- errno == EAGAIN)
- return 0;
+ errno == EAGAIN) {
+ ret = 0;
+ goto cleanup;
+ }
virReportSystemError(errno, "%s",
_("Unable to accept client"));
- return -1;
+ goto cleanup;
}
localAddr.len = sizeof(localAddr.data);
if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
virReportSystemError(errno, "%s", _("Unable to get local socket
name"));
- VIR_FORCE_CLOSE(fd);
- return -1;
+ goto cleanup;
}
if (!(*clientsock = virNetSocketNew(&localAddr,
&remoteAddr,
true,
- fd, -1, 0))) {
- VIR_FORCE_CLOSE(fd);
- return -1;
- }
+ fd, -1, 0)))
+ goto cleanup;
- return 0;
+ fd = -1;
+ ret = 0;
+
+cleanup:
+ VIR_FORCE_CLOSE(fd);
+ virMutexUnlock(&sock->lock);
+ return ret;
}
@@ -1040,52 +1108,77 @@ static void virNetSocketEventHandle(int watch ATTRIBUTE_UNUSED,
void *opaque)
{
virNetSocketPtr sock = opaque;
+ virNetSocketIOFunc func;
+ void *eopaque;
- sock->func(sock, events, sock->opaque);
+ virMutexLock(&sock->lock);
+ func = sock->func;
+ eopaque = sock->opaque;
+ virMutexUnlock(&sock->lock);
+
+ if (func)
+ func(sock, events, eopaque);
}
+
+/*
+ * virNetSocketEventFree:
+ * @opaque: the virNetSocketPtr that was registered with the event loop
+ *
+ * Free callback handed to virEventAddHandle(). It drops the reference
+ * that virNetSocketAddIOCallback() took on behalf of the event loop, so a
+ * socket whose watch has been removed can actually be released.
+ */
+static void virNetSocketEventFree(void *opaque)
+{
+ virNetSocketPtr sock = opaque;
+
+ virNetSocketFree(sock);
+}
+
int virNetSocketAddIOCallback(virNetSocketPtr sock,
int events,
virNetSocketIOFunc func,
void *opaque)
{
+ int ret = -1;
+
+ virMutexLock(&sock->lock);
if (sock->watch > 0) {
VIR_DEBUG("Watch already registered on socket %p", sock);
- return -1;
+ goto cleanup;
}
+ /* Reference owned by the event loop, released by virNetSocketEventFree
+ * once the watch is removed. NOTE(review): assumes the event loop runs
+ * the free callback later from its own context, not synchronously inside
+ * virEventRemoveHandle while sock->lock is held -- confirm for custom
+ * event implementations. */
+ sock->refs++;
if ((sock->watch = virEventAddHandle(sock->fd,
events,
virNetSocketEventHandle,
sock,
- NULL)) < 0) {
+ virNetSocketEventFree)) < 0) {
VIR_DEBUG("Failed to register watch on socket %p", sock);
- return -1;
+ /* The event loop never took ownership, so hand the reference back.
+ * The caller still holds its own reference, so refs cannot reach 0. */
+ sock->refs--;
+ goto cleanup;
}
sock->func = func;
sock->opaque = opaque;
- return 0;
+ ret = 0;
+
+cleanup:
+ virMutexUnlock(&sock->lock);
+ return ret;
}
void virNetSocketUpdateIOCallback(virNetSocketPtr sock,
int events)
{
+ virMutexLock(&sock->lock); /* sock->watch is read under the socket lock */
if (sock->watch <= 0) {
VIR_DEBUG("Watch not registered on socket %p", sock);
+ virMutexUnlock(&sock->lock);
return;
}
virEventUpdateHandle(sock->watch, events);
+
+ virMutexUnlock(&sock->lock);
}
+/*
+ * virNetSocketRemoveIOCallback:
+ * @sock: the socket whose event-loop watch should be removed
+ *
+ * Unregisters the watch added by virNetSocketAddIOCallback(). If no watch
+ * is registered, only a debug message is logged.
+ */
void virNetSocketRemoveIOCallback(virNetSocketPtr sock)
{
+ virMutexLock(&sock->lock);
+
if (sock->watch <= 0) {
VIR_DEBUG("Watch not registered on socket %p", sock);
+ virMutexUnlock(&sock->lock);
return;
}
virEventRemoveHandle(sock->watch);
+ /* Forget the handle id. Otherwise a later AddIOCallback is refused,
+ * UpdateIOCallback touches a dead handle, and virNetSocketFree removes
+ * this watch a second time. */
sock->watch = 0;
+
+ virMutexUnlock(&sock->lock);
}
--
1.7.6