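Currently the socket code will unlink any UNIX socket path that is
associated with a server socket. This is not fine-grained enough, as we
need to avoid unlinking the paths of server sockets that were passed to
us by systemd. Add an explicit 'unlinkUNIX' flag so that each caller
decides whether the UNIX socket path is removed when the socket is closed.

Signed-off-by: Daniel P. Berrangé <berrange(a)redhat.com>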
---
src/locking/lock_daemon.c | 1 +
src/logging/log_daemon.c | 1 +
src/rpc/virnetserverservice.c | 3 ++
src/rpc/virnetserverservice.h | 1 +
src/rpc/virnetsocket.c | 57 ++++++++++++++++++++---------------
src/rpc/virnetsocket.h | 1 +
6 files changed, 40 insertions(+), 24 deletions(-)
diff --git a/src/locking/lock_daemon.c b/src/locking/lock_daemon.c
index c10b2d383c..0f90606be6 100644
--- a/src/locking/lock_daemon.c
+++ b/src/locking/lock_daemon.c
@@ -619,6 +619,7 @@ virLockDaemonSetupNetworkingSystemD(virNetServerPtr lockSrv,
virNetServerPtr adm
* so the first FD we'll get is '3'. */
if (!(svc = virNetServerServiceNewFDs(fds,
ARRAY_CARDINALITY(fds),
+ false,
0,
NULL,
false, 0, 1)))
diff --git a/src/logging/log_daemon.c b/src/logging/log_daemon.c
index 6531999381..30c70a20dd 100644
--- a/src/logging/log_daemon.c
+++ b/src/logging/log_daemon.c
@@ -554,6 +554,7 @@ virLogDaemonSetupNetworkingSystemD(virNetServerPtr logSrv,
virNetServerPtr admin
* so the first FD we'll get is '3'. */
if (!(svc = virNetServerServiceNewFDs(fds,
ARRAY_CARDINALITY(fds),
+ false,
0,
NULL,
false, 0, 1)))
diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c
index 0d2f264696..315a4950df 100644
--- a/src/rpc/virnetserverservice.c
+++ b/src/rpc/virnetserverservice.c
@@ -121,6 +121,7 @@ virNetServerServiceNewFDOrUNIX(const char *path,
*/
return virNetServerServiceNewFDs(fds,
ARRAY_CARDINALITY(fds),
+ false,
auth,
tls,
readonly,
@@ -257,6 +258,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
virNetServerServicePtr virNetServerServiceNewFDs(int *fds,
size_t nfds,
+ bool unlinkUNIX,
int auth,
virNetTLSContextPtr tls,
bool readonly,
@@ -272,6 +274,7 @@ virNetServerServicePtr virNetServerServiceNewFDs(int *fds,
for (i = 0; i < nfds; i++) {
if (virNetSocketNewListenFD(fds[i],
+ unlinkUNIX,
&socks[i]) < 0)
goto cleanup;
}
diff --git a/src/rpc/virnetserverservice.h b/src/rpc/virnetserverservice.h
index 59ee51e5ee..73d61dde99 100644
--- a/src/rpc/virnetserverservice.h
+++ b/src/rpc/virnetserverservice.h
@@ -62,6 +62,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
size_t nrequests_client_max);
virNetServerServicePtr virNetServerServiceNewFDs(int *fd,
size_t nfds,
+ bool unlinkUNIX,
int auth,
virNetTLSContextPtr tls,
bool readonly,
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index fc13b1654a..a462c3eb05 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -81,6 +81,7 @@ struct _virNetSocket {
bool client;
bool ownsFd;
bool quietEOF;
+ bool unlinkUNIX;
/* Event callback fields */
virNetSocketIOFunc func;
@@ -216,10 +217,13 @@ int virNetSocketCheckProtocols(bool *hasIPv4,
}
-static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
- virSocketAddrPtr remoteAddr,
- bool isClient,
- int fd, int errfd, pid_t pid)
+static virNetSocketPtr
+virNetSocketNew(virSocketAddrPtr localAddr,
+ virSocketAddrPtr remoteAddr,
+ int fd,
+ int errfd,
+ pid_t pid,
+ bool unlinkUNIX)
{
virNetSocketPtr sock;
int no_slow_start = 1;
@@ -254,6 +258,7 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
sock->pid = pid;
sock->watch = -1;
sock->ownsFd = true;
+ sock->unlinkUNIX = unlinkUNIX;
/* Disable nagle for TCP sockets */
if (sock->localAddr.data.sa.sa_family == AF_INET ||
@@ -280,8 +285,6 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
!(sock->remoteAddrStrURI = virSocketAddrFormatFull(remoteAddr, true, NULL)))
goto error;
- sock->client = isClient;
-
PROBE(RPC_SOCKET_NEW,
"sock=%p fd=%d errfd=%d pid=%lld localAddr=%s, remoteAddr=%s",
sock, fd, errfd, (long long)pid,
@@ -427,7 +430,7 @@ int virNetSocketNewListenTCP(const char *nodename,
if (VIR_EXPAND_N(socks, nsocks, 1) < 0)
goto error;
- if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, false, fd, -1, 0)))
+ if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, fd, -1, 0, false)))
goto error;
runp = runp->ai_next;
fd = -1;
@@ -513,7 +516,7 @@ int virNetSocketNewListenUNIX(const char *path,
goto error;
}
- if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0)))
+ if (!(*retsock = virNetSocketNew(&addr, NULL, fd, -1, 0, true)))
goto error;
return 0;
@@ -538,6 +541,7 @@ int virNetSocketNewListenUNIX(const char *path ATTRIBUTE_UNUSED,
#endif
int virNetSocketNewListenFD(int fd,
+ bool unlinkUNIX,
virNetSocketPtr *retsock)
{
virSocketAddr addr;
@@ -551,7 +555,7 @@ int virNetSocketNewListenFD(int fd,
return -1;
}
- if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0)))
+ if (!(*retsock = virNetSocketNew(&addr, NULL, fd, -1, 0, unlinkUNIX)))
return -1;
return 0;
@@ -627,7 +631,7 @@ int virNetSocketNewConnectTCP(const char *nodename,
goto error;
}
- if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0)))
+ if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, fd, -1, 0,
false)))
goto error;
freeaddrinfo(ai);
@@ -752,7 +756,7 @@ int virNetSocketNewConnectUNIX(const char *path,
goto cleanup;
}
- if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0)))
+ if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, fd, -1, 0,
false)))
goto cleanup;
ret = 0;
@@ -820,7 +824,7 @@ int virNetSocketNewConnectCommand(virCommandPtr cmd,
VIR_FORCE_CLOSE(sv[1]);
VIR_FORCE_CLOSE(errfd[1]);
- if (!(*retsock = virNetSocketNew(NULL, NULL, true, sv[0], errfd[0], pid)))
+ if (!(*retsock = virNetSocketNew(NULL, NULL, sv[0], errfd[0], pid, false)))
goto error;
virCommandFree(cmd);
@@ -1219,7 +1223,7 @@ int virNetSocketNewConnectSockFD(int sockfd,
return -1;
}
- if (!(*retsock = virNetSocketNew(&localAddr, NULL, true, sockfd, -1, -1)))
+ if (!(*retsock = virNetSocketNew(&localAddr, NULL, sockfd, -1, -1, false)))
return -1;
return 0;
@@ -1231,7 +1235,7 @@ virNetSocketPtr virNetSocketNewPostExecRestart(virJSONValuePtr
object)
virSocketAddr localAddr;
virSocketAddr remoteAddr;
int fd, thepid, errfd;
- bool isClient;
+ bool unlinkUNIX;
if (virJSONValueObjectGetNumberInt(object, "fd", &fd) < 0) {
virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -1250,10 +1254,15 @@ virNetSocketPtr virNetSocketNewPostExecRestart(virJSONValuePtr
object)
_("Missing errfd data in JSON document"));
return NULL;
}
- if (virJSONValueObjectGetBoolean(object, "isClient", &isClient) < 0)
{
- virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
- _("Missing isClient data in JSON document"));
- return NULL;
+
+ if (virJSONValueObjectGetBoolean(object, "unlinkUNIX", &unlinkUNIX)
< 0) {
+ bool isClient;
+ if (virJSONValueObjectGetBoolean(object, "isClient", &isClient)
< 0) {
+ virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
+ _("Missing unlinkUNIX/isClient data in JSON
document"));
+ return NULL;
+ }
+ unlinkUNIX = !isClient;
}
memset(&localAddr, 0, sizeof(localAddr));
@@ -1272,7 +1281,7 @@ virNetSocketPtr virNetSocketNewPostExecRestart(virJSONValuePtr
object)
}
return virNetSocketNew(&localAddr, &remoteAddr,
- isClient, fd, errfd, thepid);
+ fd, errfd, thepid, unlinkUNIX);
}
@@ -1309,7 +1318,7 @@ virJSONValuePtr virNetSocketPreExecRestart(virNetSocketPtr sock)
if (virJSONValueObjectAppendNumberInt(object, "pid", sock->pid) < 0)
goto error;
- if (virJSONValueObjectAppendBoolean(object, "isClient", sock->client)
< 0)
+ if (virJSONValueObjectAppendBoolean(object, "unlinkUNIX",
sock->unlinkUNIX) < 0)
goto error;
if (virSetInherit(sock->fd, true) < 0) {
@@ -1350,7 +1359,7 @@ void virNetSocketDispose(void *obj)
#ifdef HAVE_SYS_UN_H
/* If a server socket, then unlink UNIX path */
- if (!sock->client &&
+ if (sock->unlinkUNIX &&
sock->localAddr.data.sa.sa_family == AF_UNIX &&
sock->localAddr.data.un.sun_path[0] != '\0')
unlink(sock->localAddr.data.un.sun_path);
@@ -2140,8 +2149,8 @@ int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr
*clientsock)
if (!(*clientsock = virNetSocketNew(&localAddr,
&remoteAddr,
- true,
- fd, -1, 0)))
+ fd, -1, 0,
+ false)))
goto cleanup;
fd = -1;
@@ -2272,7 +2281,7 @@ void virNetSocketClose(virNetSocketPtr sock)
#ifdef HAVE_SYS_UN_H
/* If a server socket, then unlink UNIX path */
- if (!sock->client &&
+ if (sock->unlinkUNIX &&
sock->localAddr.data.sa.sa_family == AF_UNIX &&
sock->localAddr.data.un.sun_path[0] != '\0') {
if (unlink(sock->localAddr.data.un.sun_path) == 0)
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index de5a465cde..2f626cb08f 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -58,6 +58,7 @@ int virNetSocketNewListenUNIX(const char *path,
virNetSocketPtr *addr);
int virNetSocketNewListenFD(int fd,
+ bool unlinkUNIX,
virNetSocketPtr *addr);
int virNetSocketNewConnectTCP(const char *nodename,
--
2.21.0