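Add a counter of unauthenticated clients (nclients_unauth) to virNetServer.
The counter is incremented for each unauthenticated client added to the
server, and decremented whenever such a client authenticates or is removed
from the server before completing authentication.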
Signed-off-by: Michal Privoznik <mprivozn(a)redhat.com>
---
daemon/remote.c | 21 +++++++++++++--------
src/rpc/virnetserver.c | 45 ++++++++++++++++++++++++++++++++++++++++++---
src/rpc/virnetserver.h | 3 +++
3 files changed, 58 insertions(+), 11 deletions(-)
diff --git a/daemon/remote.c b/daemon/remote.c
index b48d456..416aa40 100644
--- a/daemon/remote.c
+++ b/daemon/remote.c
@@ -2619,7 +2619,7 @@ cleanup:
/*-------------------------------------------------------------*/
static int
-remoteDispatchAuthList(virNetServerPtr server ATTRIBUTE_UNUSED,
+remoteDispatchAuthList(virNetServerPtr server,
virNetServerClientPtr client,
virNetMessagePtr msg ATTRIBUTE_UNUSED,
virNetMessageErrorPtr rerr,
@@ -2649,6 +2649,7 @@ remoteDispatchAuthList(virNetServerPtr server ATTRIBUTE_UNUSED,
goto cleanup;
VIR_INFO("Bypass polkit auth for privileged client %s", ident);
virNetServerClientSetAuth(client, 0);
+ virNetServerTrackCompletedAuth(server);
auth = VIR_NET_SERVER_SERVICE_AUTH_NONE;
VIR_FREE(ident);
}
@@ -2764,7 +2765,8 @@ authfail:
* Returns 0 if ok, -1 on error, -2 if rejected
*/
static int
-remoteSASLFinish(virNetServerClientPtr client)
+remoteSASLFinish(virNetServerPtr server,
+ virNetServerClientPtr client)
{
const char *identity;
struct daemonClientPrivate *priv = virNetServerClientGetPrivateData(client);
@@ -2789,6 +2791,7 @@ remoteSASLFinish(virNetServerClientPtr client)
return -2;
virNetServerClientSetAuth(client, 0);
+ virNetServerTrackCompletedAuth(server);
virNetServerClientSetSASLSession(client, priv->sasl);
VIR_DEBUG("Authentication successful %d",
virNetServerClientGetFD(client));
@@ -2810,7 +2813,7 @@ error:
* This starts the SASL authentication negotiation.
*/
static int
-remoteDispatchAuthSaslStart(virNetServerPtr server ATTRIBUTE_UNUSED,
+remoteDispatchAuthSaslStart(virNetServerPtr server,
virNetServerClientPtr client,
virNetMessagePtr msg ATTRIBUTE_UNUSED,
virNetMessageErrorPtr rerr,
@@ -2868,7 +2871,7 @@ remoteDispatchAuthSaslStart(virNetServerPtr server ATTRIBUTE_UNUSED,
ret->complete = 0;
} else {
/* Check username whitelist ACL */
- if ((err = remoteSASLFinish(client)) < 0) {
+ if ((err = remoteSASLFinish(server, client)) < 0) {
if (err == -2)
goto authdeny;
else
@@ -2908,7 +2911,7 @@ error:
static int
-remoteDispatchAuthSaslStep(virNetServerPtr server ATTRIBUTE_UNUSED,
+remoteDispatchAuthSaslStep(virNetServerPtr server,
virNetServerClientPtr client,
virNetMessagePtr msg ATTRIBUTE_UNUSED,
virNetMessageErrorPtr rerr,
@@ -2966,7 +2969,7 @@ remoteDispatchAuthSaslStep(virNetServerPtr server ATTRIBUTE_UNUSED,
ret->complete = 0;
} else {
/* Check username whitelist ACL */
- if ((err = remoteSASLFinish(client)) < 0) {
+ if ((err = remoteSASLFinish(server, client)) < 0) {
if (err == -2)
goto authdeny;
else
@@ -3051,7 +3054,7 @@ remoteDispatchAuthSaslStep(virNetServerPtr server ATTRIBUTE_UNUSED,
#if WITH_POLKIT1
static int
-remoteDispatchAuthPolkit(virNetServerPtr server ATTRIBUTE_UNUSED,
+remoteDispatchAuthPolkit(virNetServerPtr server,
virNetServerClientPtr client,
virNetMessagePtr msg ATTRIBUTE_UNUSED,
virNetMessageErrorPtr rerr,
@@ -3142,6 +3145,7 @@ remoteDispatchAuthPolkit(virNetServerPtr server ATTRIBUTE_UNUSED,
ret->complete = 1;
virNetServerClientSetAuth(client, 0);
+ virNetServerTrackCompletedAuth(server);
virMutexUnlock(&priv->lock);
virCommandFree(cmd);
VIR_FREE(pkout);
@@ -3182,7 +3186,7 @@ authdeny:
}
#elif WITH_POLKIT0
static int
-remoteDispatchAuthPolkit(virNetServerPtr server ATTRIBUTE_UNUSED,
+remoteDispatchAuthPolkit(virNetServerPtr server,
virNetServerClientPtr client,
virNetMessagePtr msg ATTRIBUTE_UNUSED,
virNetMessageErrorPtr rerr,
@@ -3297,6 +3301,7 @@ remoteDispatchAuthPolkit(virNetServerPtr server ATTRIBUTE_UNUSED,
ret->complete = 1;
virNetServerClientSetAuth(client, 0);
+ virNetServerTrackCompletedAuth(server);
virMutexUnlock(&priv->lock);
VIR_FREE(ident);
return 0;
diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c
index f70e260..3170f64 100644
--- a/src/rpc/virnetserver.c
+++ b/src/rpc/virnetserver.c
@@ -88,9 +88,10 @@ struct _virNetServer {
size_t nprograms;
virNetServerProgramPtr *programs;
- size_t nclients;
- size_t nclients_max;
- virNetServerClientPtr *clients;
+ size_t nclients; /* Current clients count */
+ virNetServerClientPtr *clients; /* Clients */
+ size_t nclients_max; /* Max allowed clients count */
+ size_t nclients_unauth; /* Unauthenticated clients count */
int keepaliveInterval;
unsigned int keepaliveCount;
@@ -118,6 +119,8 @@ static virClassPtr virNetServerClass;
static void virNetServerDispose(void *obj);
static void virNetServerUpdateServicesLocked(virNetServerPtr srv,
bool enabled);
+static inline size_t virNetServerTrackPendingAuthLocked(virNetServerPtr srv);
+static inline size_t virNetServerTrackCompletedAuthLocked(virNetServerPtr srv);
static int virNetServerOnceInit(void)
{
@@ -273,6 +276,9 @@ static int virNetServerAddClient(virNetServerPtr srv,
srv->clients[srv->nclients-1] = client;
virObjectRef(client);
+ if (virNetServerClientNeedAuth(client))
+ virNetServerTrackPendingAuthLocked(srv);
+
if (srv->nclients == srv->nclients_max) {
/* Temporarily stop accepting new clients */
VIR_DEBUG("Temporarily suspending services due to max_clients");
@@ -1140,6 +1146,9 @@ void virNetServerRun(virNetServerPtr srv)
srv->nclients = 0;
}
+ if (virNetServerClientNeedAuth(client))
+ virNetServerTrackCompletedAuthLocked(srv);
+
/* Enable services if we can accept a new client.
* The new client can be accepted if we are at the limit. */
if (srv->nclients == srv->nclients_max - 1) {
@@ -1236,3 +1245,70 @@ bool virNetServerKeepAliveRequired(virNetServerPtr srv)
     virObjectUnlock(srv);
     return required;
 }
+
+/*
+ * Count one more client of @srv that still has to authenticate.
+ * The caller must hold the lock on @srv.
+ *
+ * Returns the updated number of unauthenticated clients.
+ */
+static inline size_t
+virNetServerTrackPendingAuthLocked(virNetServerPtr srv)
+{
+    return ++srv->nclients_unauth;
+}
+
+/*
+ * Count one client of @srv as no longer pending authentication,
+ * either because it authenticated or because it was removed from
+ * the server.  The caller must hold the lock on @srv.
+ *
+ * The counter is a size_t, so guard against decrementing it below
+ * zero: an unbalanced call would otherwise wrap it to SIZE_MAX.
+ * NOTE(review): virNetServerClientSetAuth() and this decrement are
+ * separate steps in the auth handlers, so a client closed mid-auth
+ * may be counted twice -- confirm and consider making them atomic.
+ *
+ * Returns the updated number of unauthenticated clients.
+ */
+static inline size_t
+virNetServerTrackCompletedAuthLocked(virNetServerPtr srv)
+{
+    if (srv->nclients_unauth == 0) {
+        VIR_WARN("Unauthenticated client counter would underflow");
+        return 0;
+    }
+
+    return --srv->nclients_unauth;
+}
+
+/*
+ * Lock @srv and count one more client pending authentication.
+ *
+ * Returns the updated number of unauthenticated clients.
+ */
+size_t virNetServerTrackPendingAuth(virNetServerPtr srv)
+{
+    size_t ret;
+
+    virObjectLock(srv);
+    ret = virNetServerTrackPendingAuthLocked(srv);
+    virObjectUnlock(srv);
+    return ret;
+}
+
+/*
+ * Lock @srv and count one client as no longer pending authentication.
+ * The auth handlers call this after virNetServerClientSetAuth(client, 0).
+ *
+ * Returns the updated number of unauthenticated clients.
+ */
+size_t virNetServerTrackCompletedAuth(virNetServerPtr srv)
+{
+    size_t ret;
+
+    virObjectLock(srv);
+    ret = virNetServerTrackCompletedAuthLocked(srv);
+    virObjectUnlock(srv);
+    return ret;
+}
diff --git a/src/rpc/virnetserver.h b/src/rpc/virnetserver.h
index 1a85c02..b56540c 100644
--- a/src/rpc/virnetserver.h
+++ b/src/rpc/virnetserver.h
@@ -97,4 +97,7 @@ void virNetServerClose(virNetServerPtr srv);
bool virNetServerKeepAliveRequired(virNetServerPtr srv);
+size_t virNetServerTrackPendingAuth(virNetServerPtr srv);
+size_t virNetServerTrackCompletedAuth(virNetServerPtr srv);
+
#endif
--
1.9.0