Signed-off-by: Tim Wiederhake <twiederh(a)redhat.com>
---
src/rpc/virnetserverclient.c | 432 +++++++++++++++--------------------
1 file changed, 186 insertions(+), 246 deletions(-)
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index 7d5c0965b8..da9956f2b4 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -234,14 +234,12 @@ int virNetServerClientAddFilter(virNetServerClient *client,
virNetServerClientFilterFunc func,
void *opaque)
{
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
virNetServerClientFilter *filter;
virNetServerClientFilter **place;
- int ret;
filter = g_new0(virNetServerClientFilter, 1);
- virObjectLock(client);
-
filter->id = client->nextFilterID++;
filter->func = func;
filter->opaque = opaque;
@@ -251,21 +249,16 @@ int virNetServerClientAddFilter(virNetServerClient *client,
place = &(*place)->next;
*place = filter;
- ret = filter->id;
-
- virObjectUnlock(client);
-
- return ret;
+ return filter->id;
}
void virNetServerClientRemoveFilter(virNetServerClient *client,
int filterID)
{
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
virNetServerClientFilter *tmp;
virNetServerClientFilter *prev;
- virObjectLock(client);
-
prev = NULL;
tmp = client->filters;
while (tmp) {
@@ -281,8 +274,6 @@ void virNetServerClientRemoveFilter(virNetServerClient *client,
prev = tmp;
tmp = tmp->next;
}
-
- virObjectUnlock(client);
}
@@ -322,19 +313,19 @@ virNetServerClientCheckAccess(virNetServerClient *client)
static void virNetServerClientDispatchMessage(virNetServerClient *client,
virNetMessage *msg)
{
- virObjectLock(client);
- if (!client->dispatchFunc) {
- virNetMessageFree(msg);
- client->wantClose = true;
- virObjectUnlock(client);
- } else {
- virObjectUnlock(client);
- /* Accessing 'client' is safe, because virNetServerClientSetDispatcher
- * only permits setting 'dispatchFunc' once, so if non-NULL, it will
- * never change again
- */
- client->dispatchFunc(client, msg, client->dispatchOpaque);
+ VIR_WITH_OBJECT_LOCK_GUARD(client) {
+ if (!client->dispatchFunc) {
+ virNetMessageFree(msg);
+ client->wantClose = true;
+ return;
+ }
}
+
+ /* Accessing 'client' is safe, because virNetServerClientSetDispatcher
+ * only permits setting 'dispatchFunc' once, so if non-NULL, it will
+ * never change again
+ */
+ client->dispatchFunc(client, msg, client->dispatchOpaque);
}
@@ -343,13 +334,14 @@ static void virNetServerClientSockTimerFunc(int timer,
{
virNetServerClient *client = opaque;
virNetMessage *msg = NULL;
- virObjectLock(client);
- virEventUpdateTimeout(timer, -1);
- /* Although client->rx != NULL when this timer is enabled, it might have
- * changed since the client was unlocked in the meantime. */
- if (client->rx)
- msg = virNetServerClientDispatchRead(client);
- virObjectUnlock(client);
+
+ VIR_WITH_OBJECT_LOCK_GUARD(client) {
+ virEventUpdateTimeout(timer, -1);
+ /* Although client->rx != NULL when this timer is enabled, it might have
+ * changed since the client was unlocked in the meantime. */
+ if (client->rx)
+ msg = virNetServerClientDispatchRead(client);
+ }
if (msg)
virNetServerClientDispatchMessage(client, msg);
@@ -587,53 +579,45 @@ virJSONValue *virNetServerClientPreExecRestart(virNetServerClient *client)
g_autoptr(virJSONValue) object = virJSONValueNewObject();
g_autoptr(virJSONValue) sock = NULL;
g_autoptr(virJSONValue) priv = NULL;
-
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
if (virJSONValueObjectAppendNumberUlong(object, "id", client->id) < 0)
- goto error;
+ return NULL;
if (virJSONValueObjectAppendNumberInt(object, "auth", client->auth) < 0)
- goto error;
+ return NULL;
if (virJSONValueObjectAppendBoolean(object, "auth_pending",
client->auth_pending) < 0)
- goto error;
+ return NULL;
if (virJSONValueObjectAppendBoolean(object, "readonly",
client->readonly) < 0)
- goto error;
+ return NULL;
if (virJSONValueObjectAppendNumberUint(object, "nrequests_max",
client->nrequests_max) < 0)
- goto error;
+ return NULL;
if (client->conn_time &&
virJSONValueObjectAppendNumberLong(object, "conn_time",
client->conn_time) < 0)
- goto error;
+ return NULL;
if (!(sock = virNetSocketPreExecRestart(client->sock)))
- goto error;
+ return NULL;
if (virJSONValueObjectAppend(object, "sock", &sock) < 0)
- goto error;
+ return NULL;
if (!(priv = client->privateDataPreExecRestart(client, client->privateData)))
- goto error;
+ return NULL;
if (virJSONValueObjectAppend(object, "privateData", &priv) < 0)
- goto error;
+ return NULL;
- virObjectUnlock(client);
return g_steal_pointer(&object);
-
- error:
- virObjectUnlock(client);
- return NULL;
}
int virNetServerClientGetAuth(virNetServerClient *client)
{
- int auth;
- virObjectLock(client);
- auth = client->auth;
- virObjectUnlock(client);
- return auth;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return client->auth;
}
@@ -647,11 +631,9 @@ virNetServerClientSetAuthLocked(virNetServerClient *client,
bool virNetServerClientGetReadonly(virNetServerClient *client)
{
- bool readonly;
- virObjectLock(client);
- readonly = client->readonly;
- virObjectUnlock(client);
- return readonly;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return client->readonly;
}
@@ -659,9 +641,9 @@ void
virNetServerClientSetReadonly(virNetServerClient *client,
bool readonly)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
client->readonly = readonly;
- virObjectUnlock(client);
}
@@ -677,52 +659,48 @@ long long virNetServerClientGetTimestamp(virNetServerClient *client)
bool virNetServerClientHasTLSSession(virNetServerClient *client)
{
- bool has;
- virObjectLock(client);
- has = client->tls ? true : false;
- virObjectUnlock(client);
- return has;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return !!client->tls;
}
virNetTLSSession *virNetServerClientGetTLSSession(virNetServerClient *client)
{
- virNetTLSSession *tls;
- virObjectLock(client);
- tls = client->tls;
- virObjectUnlock(client);
- return tls;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return client->tls;
}
int virNetServerClientGetTLSKeySize(virNetServerClient *client)
{
- int size = 0;
- virObjectLock(client);
- if (client->tls)
- size = virNetTLSSessionGetKeySize(client->tls);
- virObjectUnlock(client);
- return size;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ if (!client->tls)
+ return 0;
+
+ return virNetTLSSessionGetKeySize(client->tls);
}
int virNetServerClientGetFD(virNetServerClient *client)
{
- int fd = -1;
- virObjectLock(client);
- if (client->sock)
- fd = virNetSocketGetFD(client->sock);
- virObjectUnlock(client);
- return fd;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ if (!client->sock)
+ return -1;
+
+ return virNetSocketGetFD(client->sock);
}
bool virNetServerClientIsLocal(virNetServerClient *client)
{
- bool local = false;
- virObjectLock(client);
- if (client->sock)
- local = virNetSocketIsLocal(client->sock);
- virObjectUnlock(client);
- return local;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ if (!client->sock)
+ return false;
+
+ return virNetSocketIsLocal(client->sock);
}
@@ -730,14 +708,12 @@ int virNetServerClientGetUNIXIdentity(virNetServerClient *client,
uid_t *uid, gid_t *gid, pid_t *pid,
unsigned long long *timestamp)
{
- int ret = -1;
- virObjectLock(client);
- if (client->sock)
- ret = virNetSocketGetUNIXIdentity(client->sock,
- uid, gid, pid,
- timestamp);
- virObjectUnlock(client);
- return ret;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ if (!client->sock)
+ return -1;
+
+ return virNetSocketGetUNIXIdentity(client->sock, uid, gid, pid, timestamp);
}
@@ -806,56 +782,60 @@ virNetServerClientCreateIdentity(virNetServerClient *client)
virIdentity *virNetServerClientGetIdentity(virNetServerClient *client)
{
- virIdentity *ret = NULL;
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
if (!client->identity)
client->identity = virNetServerClientCreateIdentity(client);
- if (client->identity)
- ret = g_object_ref(client->identity);
- virObjectUnlock(client);
- return ret;
+
+ if (!client->identity)
+ return NULL;
+
+ return g_object_ref(client->identity);
}
void virNetServerClientSetIdentity(virNetServerClient *client,
virIdentity *identity)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
g_clear_object(&client->identity);
client->identity = identity;
if (client->identity)
g_object_ref(client->identity);
- virObjectUnlock(client);
}
int virNetServerClientGetSELinuxContext(virNetServerClient *client,
char **context)
{
- int ret = 0;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
*context = NULL;
- virObjectLock(client);
- if (client->sock)
- ret = virNetSocketGetSELinuxContext(client->sock, context);
- virObjectUnlock(client);
- return ret;
+
+ if (!client->sock)
+ return 0;
+
+ return virNetSocketGetSELinuxContext(client->sock, context);
}
bool virNetServerClientIsSecure(virNetServerClient *client)
{
- bool secure = false;
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
if (client->tls)
- secure = true;
+ return true;
+
#if WITH_SASL
if (client->sasl)
- secure = true;
+ return true;
#endif
+
if (client->sock && virNetSocketIsLocal(client->sock))
- secure = true;
- virObjectUnlock(client);
- return secure;
+ return true;
+
+ return false;
}
@@ -863,53 +843,47 @@ bool virNetServerClientIsSecure(virNetServerClient *client)
void virNetServerClientSetSASLSession(virNetServerClient *client,
virNetSASLSession *sasl)
{
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
/* We don't set the sasl session on the socket here
* because we need to send out the auth confirmation
* in the clear. Only once we complete the next 'tx'
* operation do we switch to SASL mode
*/
- virObjectLock(client);
client->sasl = virObjectRef(sasl);
- virObjectUnlock(client);
}
virNetSASLSession *virNetServerClientGetSASLSession(virNetServerClient *client)
{
- virNetSASLSession *sasl;
- virObjectLock(client);
- sasl = client->sasl;
- virObjectUnlock(client);
- return sasl;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return client->sasl;
}
bool virNetServerClientHasSASLSession(virNetServerClient *client)
{
- bool has = false;
- virObjectLock(client);
- has = !!client->sasl;
- virObjectUnlock(client);
- return has;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return !!client->sasl;
}
#endif
void *virNetServerClientGetPrivateData(virNetServerClient *client)
{
- void *data;
- virObjectLock(client);
- data = client->privateData;
- virObjectUnlock(client);
- return data;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return client->privateData;
}
void virNetServerClientSetCloseHook(virNetServerClient *client,
virNetServerClientCloseFunc cf)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
client->privateDataCloseFunc = cf;
- virObjectUnlock(client);
}
@@ -917,7 +891,8 @@ void virNetServerClientSetDispatcher(virNetServerClient *client,
virNetServerClientDispatchFunc func,
void *opaque)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
/* Only set dispatcher if not already set, to avoid race
* with dispatch code that runs without locks held
*/
@@ -925,7 +900,6 @@ void virNetServerClientSetDispatcher(virNetServerClient *client,
client->dispatchFunc = func;
client->dispatchOpaque = opaque;
}
- virObjectUnlock(client);
}
@@ -1042,9 +1016,9 @@ virNetServerClientCloseLocked(virNetServerClient *client)
void
virNetServerClientClose(virNetServerClient *client)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
virNetServerClientCloseLocked(client);
- virObjectUnlock(client);
}
@@ -1057,16 +1031,16 @@ virNetServerClientIsClosedLocked(virNetServerClient *client)
void virNetServerClientDelayedClose(virNetServerClient *client)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
client->delayedClose = true;
- virObjectUnlock(client);
}
void virNetServerClientImmediateClose(virNetServerClient *client)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
client->wantClose = true;
- virObjectUnlock(client);
}
@@ -1079,49 +1053,46 @@ virNetServerClientWantCloseLocked(virNetServerClient *client)
int virNetServerClientInit(virNetServerClient *client)
{
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+ int ret = -1;
if (!client->tlsCtxt) {
/* Plain socket, so prepare to read first message */
if (virNetServerClientRegisterEvent(client) < 0)
goto error;
- } else {
- int ret;
+ return 0;
+ }
- if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt,
- NULL)))
- goto error;
+ if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt, NULL)))
+ goto error;
- virNetSocketSetTLSSession(client->sock,
- client->tls);
+ virNetSocketSetTLSSession(client->sock, client->tls);
- /* Begin the TLS handshake. */
- virObjectLock(client->tlsCtxt);
+ /* Begin the TLS handshake. */
+ VIR_WITH_OBJECT_LOCK_GUARD(client->tlsCtxt) {
ret = virNetTLSSessionHandshake(client->tls);
- virObjectUnlock(client->tlsCtxt);
- if (ret == 0) {
- /* Unlikely, but ... Next step is to check the certificate. */
- if (virNetServerClientCheckAccess(client) < 0)
- goto error;
-
- /* Handshake & cert check OK, so prepare to read first message */
- if (virNetServerClientRegisterEvent(client) < 0)
- goto error;
- } else if (ret > 0) {
- /* Most likely, need to do more handshake data */
- if (virNetServerClientRegisterEvent(client) < 0)
- goto error;
- } else {
+ }
+
+ if (ret == 0) {
+ /* Unlikely, but ... Next step is to check the certificate. */
+ if (virNetServerClientCheckAccess(client) < 0)
goto error;
- }
+
+ /* Handshake & cert check OK, so prepare to read first message */
+ if (virNetServerClientRegisterEvent(client) < 0)
+ goto error;
+ } else if (ret > 0) {
+ /* Most likely, need to do more handshake data */
+ if (virNetServerClientRegisterEvent(client) < 0)
+ goto error;
+ } else {
+ goto error;
}
- virObjectUnlock(client);
return 0;
error:
client->wantClose = true;
- virObjectUnlock(client);
return -1;
}
@@ -1406,11 +1377,13 @@ virNetServerClientDispatchWrite(virNetServerClient *client)
static void
virNetServerClientDispatchHandshake(virNetServerClient *client)
{
- int ret;
+ int ret = -1;
+
/* Continue the handshake. */
- virObjectLock(client->tlsCtxt);
- ret = virNetTLSSessionHandshake(client->tls);
- virObjectUnlock(client->tlsCtxt);
+ VIR_WITH_OBJECT_LOCK_GUARD(client->tlsCtxt) {
+ ret = virNetTLSSessionHandshake(client->tls);
+ }
+
if (ret == 0) {
/* Finished. Next step is to check the certificate. */
if (virNetServerClientCheckAccess(client) < 0)
@@ -1435,36 +1408,29 @@ virNetServerClientDispatchEvent(virNetSocket *sock, int events, void *opaque)
virNetServerClient *client = opaque;
virNetMessage *msg = NULL;
- virObjectLock(client);
-
- if (client->sock != sock) {
- virNetSocketRemoveIOCallback(sock);
- virObjectUnlock(client);
- return;
- }
-
- if (events & (VIR_EVENT_HANDLE_WRITABLE |
- VIR_EVENT_HANDLE_READABLE)) {
- if (client->tls &&
- virNetTLSSessionGetHandshakeStatus(client->tls) !=
- VIR_NET_TLS_HANDSHAKE_COMPLETE) {
- virNetServerClientDispatchHandshake(client);
- } else {
- if (events & VIR_EVENT_HANDLE_WRITABLE)
- virNetServerClientDispatchWrite(client);
- if (events & VIR_EVENT_HANDLE_READABLE &&
- client->rx)
- msg = virNetServerClientDispatchRead(client);
+ VIR_WITH_OBJECT_LOCK_GUARD(client) {
+ if (client->sock != sock) {
+ virNetSocketRemoveIOCallback(sock);
+ return;
}
- }
- /* NB, will get HANGUP + READABLE at same time upon
- * disconnect */
- if (events & (VIR_EVENT_HANDLE_ERROR |
- VIR_EVENT_HANDLE_HANGUP))
- client->wantClose = true;
+ if (events & (VIR_EVENT_HANDLE_WRITABLE | VIR_EVENT_HANDLE_READABLE)) {
+ if (client->tls &&
+ virNetTLSSessionGetHandshakeStatus(client->tls) !=
+ VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+ virNetServerClientDispatchHandshake(client);
+ } else {
+ if (events & VIR_EVENT_HANDLE_WRITABLE)
+ virNetServerClientDispatchWrite(client);
+ if ((events & VIR_EVENT_HANDLE_READABLE) && client->rx)
+ msg = virNetServerClientDispatchRead(client);
+ }
+ }
- virObjectUnlock(client);
+ /* NB, will get HANGUP + READABLE at same time upon disconnect */
+ if (events & (VIR_EVENT_HANDLE_ERROR | VIR_EVENT_HANDLE_HANGUP))
+ client->wantClose = true;
+ }
if (msg)
virNetServerClientDispatchMessage(client, msg);
@@ -1499,24 +1465,18 @@ virNetServerClientSendMessageLocked(virNetServerClient *client,
int virNetServerClientSendMessage(virNetServerClient *client,
virNetMessage *msg)
{
- int ret;
-
- virObjectLock(client);
- ret = virNetServerClientSendMessageLocked(client, msg);
- virObjectUnlock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
- return ret;
+ return virNetServerClientSendMessageLocked(client, msg);
}
bool
virNetServerClientIsAuthenticated(virNetServerClient *client)
{
- bool authenticated;
- virObjectLock(client);
- authenticated = virNetServerClientAuthMethodImpliesAuthenticated(client->auth);
- virObjectUnlock(client);
- return authenticated;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
+
+ return virNetServerClientAuthMethodImpliesAuthenticated(client->auth);
}
@@ -1556,57 +1516,44 @@ virNetServerClientInitKeepAlive(virNetServerClient *client,
int interval,
unsigned int count)
{
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
virKeepAlive *ka;
- int ret = -1;
-
- virObjectLock(client);
if (!(ka = virKeepAliveNew(interval, count, client,
virNetServerClientKeepAliveSendCB,
virNetServerClientKeepAliveDeadCB,
virObjectFreeCallback)))
- goto cleanup;
+ return -1;
+
/* keepalive object has a reference to client */
virObjectRef(client);
client->keepalive = ka;
- ret = 0;
- cleanup:
- virObjectUnlock(client);
-
- return ret;
+ return 0;
}
int
virNetServerClientStartKeepAlive(virNetServerClient *client)
{
- int ret = -1;
-
- virObjectLock(client);
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
/* The connection might have been closed before we got here and thus the
* keepalive object could have been removed too.
*/
if (!client->keepalive) {
- virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
- _("connection not open"));
- goto cleanup;
+ virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("connection not open"));
+ return -1;
}
- ret = virKeepAliveStart(client->keepalive, 0, 0);
-
- cleanup:
- virObjectUnlock(client);
- return ret;
+ return virKeepAliveStart(client->keepalive, 0, 0);
}
int
virNetServerClientGetTransport(virNetServerClient *client)
{
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
int ret = -1;
- virObjectLock(client);
-
if (client->sock && virNetSocketIsLocal(client->sock))
ret = VIR_CLIENT_TRANS_UNIX;
else
@@ -1615,8 +1562,6 @@ virNetServerClientGetTransport(virNetServerClient *client)
if (client->tls)
ret = VIR_CLIENT_TRANS_TLS;
- virObjectUnlock(client);
-
return ret;
}
@@ -1625,16 +1570,15 @@ virNetServerClientGetInfo(virNetServerClient *client,
bool *readonly, char **sock_addr,
virIdentity **identity)
{
- int ret = -1;
+ VIR_LOCK_GUARD lock = virObjectLockGuard(client);
const char *addr;
- virObjectLock(client);
*readonly = client->readonly;
if (!(addr = virNetServerClientRemoteAddrStringURI(client))) {
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
_("No network socket associated with client"));
- goto cleanup;
+ return -1;
}
*sock_addr = g_strdup(addr);
@@ -1642,15 +1586,11 @@ virNetServerClientGetInfo(virNetServerClient *client,
if (!client->identity) {
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
_("No identity information available for client"));
- goto cleanup;
+ return -1;
}
*identity = g_object_ref(client->identity);
-
- ret = 0;
- cleanup:
- virObjectUnlock(client);
- return ret;
+ return 0;
}
--
2.31.1