From: "Daniel P. Berrange" <berrange(a)redhat.com>
The code calling sendfd/recvfd was mistakenly assuming those
calls would never block. They can in fact return EAGAIN and
this is causing us to drop the client connection when blocking
occurs while sending/receiving FDs.
Fixing this is a little hairy on the incoming side, since at
the point where we see the EAGAIN, we already thought we had
finished receiving all data for the packet. So we play a little
trick to reset bufferOffset again and go back into polling for
more data.
* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Update
virNetSocketSendFD/RecvFD to return 0 on EAGAIN, or 1
on success
* src/rpc/virnetclient.c: Move decoding of header & fds
out of virNetClientCallDispatch and into virNetClientIOHandleInput.
Handle blocking when sending/receiving FDs
* src/rpc/virnetmessage.h: Add a 'donefds' field to track
how many FDs we've sent / received
* src/rpc/virnetserverclient.c: Handle blocking when
sending/receiving FDs
---
src/rpc/virnetclient.c | 79 ++++++++++++++++++++++++++++--------------
src/rpc/virnetmessage.h | 1 +
src/rpc/virnetserverclient.c | 62 ++++++++++++++++++++++++---------
src/rpc/virnetsocket.c | 34 +++++++++++++-----
src/rpc/virnetsocket.h | 2 +-
5 files changed, 125 insertions(+), 53 deletions(-)
diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 2b5f67c..4b7d4a9 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -694,10 +694,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
static int
virNetClientCallDispatch(virNetClientPtr client)
{
- size_t i;
- if (virNetMessageDecodeHeader(&client->msg) < 0)
- return -1;
-
PROBE(RPC_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u
serial=%u",
client, client->msg.bufferLength,
@@ -706,15 +702,7 @@ virNetClientCallDispatch(virNetClientPtr client)
switch (client->msg.header.type) {
case VIR_NET_REPLY: /* Normal RPC replies */
- return virNetClientCallDispatchReply(client);
-
case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
- if (virNetMessageDecodeNumFDs(&client->msg) < 0)
- return -1;
- for (i = 0 ; i < client->msg.nfds ; i++) {
- if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
- return -1;
- }
return virNetClientCallDispatchReply(client);
case VIR_NET_MESSAGE: /* Async notifications */
@@ -737,22 +725,29 @@ static ssize_t
virNetClientIOWriteMessage(virNetClientPtr client,
virNetClientCallPtr thecall)
{
- ssize_t ret;
+ ssize_t ret = 0;
- ret = virNetSocketWrite(client->sock,
- thecall->msg->buffer +
thecall->msg->bufferOffset,
- thecall->msg->bufferLength -
thecall->msg->bufferOffset);
- if (ret <= 0)
- return ret;
+ if (thecall->msg->bufferOffset < thecall->msg->bufferLength) {
+ ret = virNetSocketWrite(client->sock,
+ thecall->msg->buffer +
thecall->msg->bufferOffset,
+ thecall->msg->bufferLength -
thecall->msg->bufferOffset);
+ if (ret <= 0)
+ return ret;
- thecall->msg->bufferOffset += ret;
+ thecall->msg->bufferOffset += ret;
+ }
if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
size_t i;
- for (i = 0 ; i < thecall->msg->nfds ; i++) {
- if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0)
+ for (i = thecall->msg->donefds ; i < thecall->msg->nfds ; i++) {
+ int rv;
+ if ((rv = virNetSocketSendFD(client->sock, thecall->msg->fds[i]))
< 0)
return -1;
+ if (rv == 0) /* Blocking */
+ return 0;
+ thecall->msg->donefds++;
}
+ thecall->msg->donefds = 0;
thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
if (thecall->expectReply)
thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
@@ -821,12 +816,16 @@ virNetClientIOHandleInput(virNetClientPtr client)
* EAGAIN
*/
for (;;) {
- ssize_t ret = virNetClientIOReadMessage(client);
+ ssize_t ret;
- if (ret < 0)
- return -1;
- if (ret == 0)
- return 0; /* Blocking on read */
+ if (client->msg.nfds == 0) {
+ ret = virNetClientIOReadMessage(client);
+
+ if (ret < 0)
+ return -1;
+ if (ret == 0)
+ return 0; /* Blocking on read */
+ }
/* Check for completion of our goal */
if (client->msg.bufferOffset == client->msg.bufferLength) {
@@ -842,6 +841,33 @@ virNetClientIOHandleInput(virNetClientPtr client)
* next iteration.
*/
} else {
+ if (virNetMessageDecodeHeader(&client->msg) < 0)
+ return -1;
+
+ if (client->msg.header.type == VIR_NET_REPLY_WITH_FDS) {
+ size_t i;
+ if (virNetMessageDecodeNumFDs(&client->msg) < 0)
+ return -1;
+
+ for (i = client->msg.donefds ; i < client->msg.nfds ; i++)
{
+ int rv;
+ if ((rv = virNetSocketRecvFD(client->sock,
&(client->msg.fds[i]))) < 0)
+ return -1;
+ if (rv == 0) /* Blocking */
+ break;
+ client->msg.donefds++;
+ }
+
+ if (client->msg.donefds < client->msg.nfds) {
+ /* Because DecodeHeader/NumFDs reset bufferOffset, we
+ * put it back to what it was, so everything works
+ * again next time we run this method
+ */
+ client->msg.bufferOffset = client->msg.bufferLength;
+ return 0; /* Blocking on more fds */
+ }
+ }
+
ret = virNetClientCallDispatch(client);
client->msg.bufferOffset = client->msg.bufferLength = 0;
/*
@@ -1257,6 +1283,7 @@ int virNetClientSend(virNetClientPtr client,
goto cleanup;
}
+ msg->donefds = 0;
if (msg->bufferLength)
call->mode = VIR_NET_CLIENT_MODE_WAIT_TX;
else
diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h
index ad63409..c54e7c6 100644
--- a/src/rpc/virnetmessage.h
+++ b/src/rpc/virnetmessage.h
@@ -48,6 +48,7 @@ struct _virNetMessage {
size_t nfds;
int *fds;
+ size_t donefds;
virNetMessagePtr next;
};
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index 2f5ae8f..cf97b58 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -771,9 +771,11 @@ static ssize_t virNetServerClientRead(virNetServerClientPtr client)
static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{
readmore:
- if (virNetServerClientRead(client) < 0) {
- client->wantClose = true;
- return; /* Error */
+ if (client->rx->nfds == 0) {
+ if (virNetServerClientRead(client) < 0) {
+ client->wantClose = true;
+ return; /* Error */
+ }
}
if (client->rx->bufferOffset < client->rx->bufferLength)
@@ -794,7 +796,7 @@ readmore:
goto readmore;
} else {
/* Grab the completed message */
- virNetMessagePtr msg = virNetMessageQueueServe(&client->rx);
+ virNetMessagePtr msg = client->rx;
virNetServerClientFilterPtr filter;
size_t i;
@@ -805,20 +807,40 @@ readmore:
return;
}
+ /* Now figure out if we need to read more data to get some
+ * file descriptors */
if (msg->header.type == VIR_NET_CALL_WITH_FDS &&
virNetMessageDecodeNumFDs(msg) < 0) {
virNetMessageFree(msg);
client->wantClose = true;
- return;
+ return; /* Error */
}
- for (i = 0 ; i < msg->nfds ; i++) {
- if ((msg->fds[i] = virNetSocketRecvFD(client->sock)) < 0) {
+
+ /* Try getting the file descriptors (may fail if blocking) */
+ for (i = msg->donefds ; i < msg->nfds ; i++) {
+ int rv;
+ if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0)
{
virNetMessageFree(msg);
client->wantClose = true;
return;
}
+ if (rv == 0) /* Blocking */
+ break;
+ msg->donefds++;
+ }
+
+ /* Need to poll() until FDs arrive */
+ if (msg->donefds < msg->nfds) {
+ /* Because DecodeHeader/NumFDs reset bufferOffset, we
+ * put it back to what it was, so everything works
+ * again next time we run this method
+ */
+ client->rx->bufferOffset = client->rx->bufferLength;
+ return;
}
+ /* Definitely finished reading, so remove from queue */
+ virNetMessageQueueServe(&client->rx);
PROBE(RPC_SERVER_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u
serial=%u",
client, msg->bufferLength,
@@ -912,25 +934,30 @@ static void
virNetServerClientDispatchWrite(virNetServerClientPtr client)
{
while (client->tx) {
- ssize_t ret;
-
- ret = virNetServerClientWrite(client);
- if (ret < 0) {
- client->wantClose = true;
- return;
+ if (client->tx->bufferOffset < client->tx->bufferLength) {
+ ssize_t ret;
+ ret = virNetServerClientWrite(client);
+ if (ret < 0) {
+ client->wantClose = true;
+ return;
+ }
+ if (ret == 0)
+ return; /* Would block on write EAGAIN */
}
- if (ret == 0)
- return; /* Would block on write EAGAIN */
if (client->tx->bufferOffset == client->tx->bufferLength) {
virNetMessagePtr msg;
size_t i;
- for (i = 0 ; i < client->tx->nfds ; i++) {
- if (virNetSocketSendFD(client->sock, client->tx->fds[i]) < 0)
{
+ for (i = client->tx->donefds ; i < client->tx->nfds ; i++) {
+ int rv;
+ if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i]))
< 0) {
client->wantClose = true;
return;
}
+ if (rv == 0) /* Blocking */
+ return;
+ client->tx->donefds++;
}
#if HAVE_SASL
@@ -1041,6 +1068,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client,
msg->bufferLength, msg->bufferOffset);
virNetServerClientLock(client);
+ msg->donefds = 0;
if (client->sock && !client->wantClose) {
PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u
serial=%u",
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index d832c53..4517d16 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -1142,6 +1142,9 @@ ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf,
size_t len)
}
+/*
+ * Returns 1 if an FD was sent, 0 if it would block, -1 on error
+ */
int virNetSocketSendFD(virNetSocketPtr sock, int fd)
{
int ret = -1;
@@ -1154,12 +1157,15 @@ int virNetSocketSendFD(virNetSocketPtr sock, int fd)
PROBE(RPC_SOCKET_SEND_FD,
"sock=%p fd=%d", sock, fd);
if (sendfd(sock->fd, fd) < 0) {
- virReportSystemError(errno,
- _("Failed to send file descriptor %d"),
- fd);
+ if (errno == EAGAIN)
+ ret = 0;
+ else
+ virReportSystemError(errno,
+ _("Failed to send file descriptor %d"),
+ fd);
goto cleanup;
}
- ret = 0;
+ ret = 1;
cleanup:
virMutexUnlock(&sock->lock);
@@ -1167,9 +1173,15 @@ cleanup:
}
-int virNetSocketRecvFD(virNetSocketPtr sock)
+/*
+ * Returns 1 if an FD was read, 0 if it would block, -1 on error
+ */
+int virNetSocketRecvFD(virNetSocketPtr sock, int *fd)
{
int ret = -1;
+
+ *fd = -1;
+
if (!virNetSocketHasPassFD(sock)) {
virNetError(VIR_ERR_INTERNAL_ERROR,
_("Receiving file descriptors is not supported on this
socket"));
@@ -1177,13 +1189,17 @@ int virNetSocketRecvFD(virNetSocketPtr sock)
}
virMutexLock(&sock->lock);
- if ((ret = recvfd(sock->fd, O_CLOEXEC)) < 0) {
- virReportSystemError(errno, "%s",
- _("Failed to recv file descriptor"));
+ if ((*fd = recvfd(sock->fd, O_CLOEXEC)) < 0) {
+ if (errno == EAGAIN)
+ ret = 0;
+ else
+ virReportSystemError(errno, "%s",
+ _("Failed to recv file descriptor"));
goto cleanup;
}
PROBE(RPC_SOCKET_RECV_FD,
- "sock=%p fd=%d", sock, ret);
+ "sock=%p fd=%d", sock, *fd);
+ ret = 1;
cleanup:
virMutexUnlock(&sock->lock);
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 13cbb14..e444aef 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -97,7 +97,7 @@ ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
int virNetSocketSendFD(virNetSocketPtr sock, int fd);
-int virNetSocketRecvFD(virNetSocketPtr sock);
+int virNetSocketRecvFD(virNetSocketPtr sock, int *fd);
void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess);
--
1.7.6.4