If, for example, two threads call abort, one of them
will hang: the client sends two abort messages and the
server replies only to the first of them; the second is
ignored. On the server reply the client marks only one of
the abort messages as complete, so the other waits forever.
There are other similar issues.
We should complete all messages waiting for a reply if we get
an error or the expected abort/finish reply from the server. Also, if one
thread sends finish and another sends abort, one of them will win
the race and the server will either abort or finish the stream. If
the stream is aborted, the thread that requested finishing should report
an error. To achieve this, keep the stream closing reason
in the @closed field. If we receive a VIR_NET_OK message for a stream,
the stream is finished if the oldest (closest to the head of the wait
queue) abort/finish message for that stream is a finish message, and it
is aborted if that message is an abort message. If there is no such
message, it is a protocol error.
We also need to fix the handling of a received VIR_NET_CONTINUE
message. Currently we take the oldest message in the queue and check
whether it is a dummy message. If one thread first sends abort and a
second thread then receives data, the oldest message is the abort
message, and the second thread won't be notified when data arrives.
Find the oldest dummy message instead.
Signed-off-by: Nikolay Shirokovskiy <nshirokovskiy(a)virtuozzo.com>
---
src/rpc/virnetclient.c | 74 ++++++++++++++++++++++++++++----------------
src/rpc/virnetclientstream.c | 47 +++++++++++++++++++++++++---
src/rpc/virnetclientstream.h | 9 ++++++
3 files changed, 100 insertions(+), 30 deletions(-)
diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 70192a9..64855fb 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -1158,6 +1158,19 @@ static int virNetClientCallDispatchMessage(virNetClientPtr client)
return 0;
}
+static void virNetClientCallCompleteAllWaitingReply(virNetClientPtr client)
+{
+ virNetClientCallPtr call;
+
+ for (call = client->waitDispatch; call; call = call->next) {
+ if (call->msg->header.prog == client->msg.header.prog &&
+ call->msg->header.vers == client->msg.header.vers &&
+ call->msg->header.serial == client->msg.header.serial &&
+ call->expectReply)
+ call->mode = VIR_NET_CLIENT_MODE_COMPLETE;
+ }
+}
+
static int virNetClientCallDispatchStream(virNetClientPtr client)
{
size_t i;
@@ -1181,16 +1194,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
return 0;
}
- /* Finish/Abort are synchronous, so also see if there's an
- * (optional) call waiting for this stream packet */
- thecall = client->waitDispatch;
- while (thecall &&
- !(thecall->msg->header.prog == client->msg.header.prog &&
- thecall->msg->header.vers == client->msg.header.vers &&
- thecall->msg->header.serial == client->msg.header.serial))
- thecall = thecall->next;
-
- VIR_DEBUG("Found call %p", thecall);
/* Status is either
* - VIR_NET_OK - no payload for streams
@@ -1202,25 +1205,47 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
if (virNetClientStreamQueuePacket(st, &client->msg) < 0)
return -1;
- if (thecall && thecall->expectReply) {
- if (thecall->msg->header.status == VIR_NET_CONTINUE) {
- VIR_DEBUG("Got a synchronous confirm");
- thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
- } else {
- VIR_DEBUG("Not completing call with status %d",
thecall->msg->header.status);
- }
+ /* Find oldest dummy message waiting for incoming data. */
+ for (thecall = client->waitDispatch; thecall; thecall = thecall->next) {
+ if (thecall->msg->header.prog == client->msg.header.prog &&
+ thecall->msg->header.vers == client->msg.header.vers &&
+ thecall->msg->header.serial == client->msg.header.serial
&&
+ thecall->expectReply &&
+ thecall->msg->header.status == VIR_NET_CONTINUE)
+ break;
+ }
+
+ if (thecall) {
+ VIR_DEBUG("Got a new incoming stream data");
+ thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
}
return 0;
}
case VIR_NET_OK:
- if (thecall && thecall->expectReply) {
- VIR_DEBUG("Got a synchronous confirm");
- thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
- } else {
+ /* Find oldest abort/finish message. */
+ for (thecall = client->waitDispatch; thecall; thecall = thecall->next) {
+ if (thecall->msg->header.prog == client->msg.header.prog &&
+ thecall->msg->header.vers == client->msg.header.vers &&
+ thecall->msg->header.serial == client->msg.header.serial
&&
+ thecall->expectReply &&
+ thecall->msg->header.status != VIR_NET_CONTINUE)
+ break;
+ }
+
+ if (!thecall) {
VIR_DEBUG("Got unexpected async stream finish confirmation");
return -1;
}
+
+ VIR_DEBUG("Got a synchronous abort/finish confirm");
+
+ virNetClientStreamSetClosed(st,
+ thecall->msg->header.status == VIR_NET_OK ?
+ VIR_NET_CLIENT_STREAM_CLOSED_FINISHED :
+ VIR_NET_CLIENT_STREAM_CLOSED_ABORTED);
+
+ virNetClientCallCompleteAllWaitingReply(client);
return 0;
case VIR_NET_ERROR:
@@ -1228,10 +1253,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
if (virNetClientStreamSetError(st, &client->msg) < 0)
return -1;
- if (thecall && thecall->expectReply) {
- VIR_DEBUG("Got a synchronous error");
- thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
- }
+ virNetClientCallCompleteAllWaitingReply(client);
return 0;
default:
@@ -2205,7 +2227,7 @@ int virNetClientSendStream(virNetClientPtr client,
if (virNetClientSendInternal(client, msg, expectReply, false) < 0)
goto cleanup;
- if (virNetClientStreamCheckSendStatus(st, msg) < 0)
+ if (expectReply && virNetClientStreamCheckSendStatus(st, msg) < 0)
goto cleanup;
ret = 0;
diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c
index cfdaa74..583cd369 100644
--- a/src/rpc/virnetclientstream.c
+++ b/src/rpc/virnetclientstream.c
@@ -49,6 +49,7 @@ struct _virNetClientStream {
*/
virNetMessagePtr rx;
bool incomingEOF;
+ int closed; /* enum virNetClientStreamClosed */
bool allowSkip;
long long holeLength; /* Size of incoming hole in stream. */
@@ -84,7 +85,7 @@ virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
VIR_DEBUG("Check timer rx=%p cbEvents=%d", st->rx, st->cbEvents);
- if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK) &&
+ if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK ||
st->closed) &&
(st->cbEvents & VIR_STREAM_EVENT_READABLE)) ||
(st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) {
VIR_DEBUG("Enabling event timer");
@@ -106,7 +107,7 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void
*opaque)
if (st->cb &&
(st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
- (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK))
+ (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK ||
st->closed))
events |= VIR_STREAM_EVENT_READABLE;
if (st->cb &&
(st->cbEvents & VIR_STREAM_EVENT_WRITABLE))
@@ -203,23 +204,61 @@ int virNetClientStreamCheckState(virNetClientStreamPtr st)
return -1;
}
+ if (st->closed) {
+ virReportError(VIR_ERR_OPERATION_FAILED, "%s",
+ _("stream is closed"));
+ return -1;
+ }
+
return 0;
}
-/* MUST be called under stream or client lock */
+/* MUST be called under stream or client lock. This should
+ * be called only for message that expect reply. */
int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st,
- virNetMessagePtr msg ATTRIBUTE_UNUSED)
+ virNetMessagePtr msg)
{
if (st->err.code != VIR_ERR_OK) {
virNetClientStreamRaiseError(st);
return -1;
}
+ /* We can not check if the message is dummy in a usual way
+ * by checking msg->bufferLength because at this point message payload
+ * is cleared. As caller must not call this function for messages
+ * not expecting reply we can check for dummy messages just by status.
+ */
+ if (msg->header.status == VIR_NET_CONTINUE) {
+ if (st->closed) {
+ virReportError(VIR_ERR_OPERATION_FAILED, "%s",
+ _("stream is closed"));
+ return -1;
+ }
+ return 0;
+ } else if (msg->header.status == VIR_NET_OK &&
+ st->closed != VIR_NET_CLIENT_STREAM_CLOSED_FINISHED) {
+ virReportError(VIR_ERR_OPERATION_FAILED, "%s",
+ _("stream aborted by another thread"));
+ return -1;
+ }
+
return 0;
}
+void virNetClientStreamSetClosed(virNetClientStreamPtr st,
+ int closed)
+{
+ virObjectLock(st);
+
+ st->closed = closed;
+ virNetClientStreamEventTimerUpdate(st);
+
+ virObjectUnlock(st);
+}
+
+
int virNetClientStreamSetError(virNetClientStreamPtr st,
virNetMessagePtr msg)
{
diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h
index 49b74bc..cb28428 100644
--- a/src/rpc/virnetclientstream.h
+++ b/src/rpc/virnetclientstream.h
@@ -27,6 +27,12 @@
typedef struct _virNetClientStream virNetClientStream;
typedef virNetClientStream *virNetClientStreamPtr;
+typedef enum {
+ VIR_NET_CLIENT_STREAM_CLOSED_NOT = 0,
+ VIR_NET_CLIENT_STREAM_CLOSED_FINISHED,
+ VIR_NET_CLIENT_STREAM_CLOSED_ABORTED,
+} virNetClientStreamClosed;
+
typedef void (*virNetClientStreamEventCallback)(virNetClientStreamPtr stream,
int events, void *opaque);
@@ -43,6 +49,9 @@ int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st,
int virNetClientStreamSetError(virNetClientStreamPtr st,
virNetMessagePtr msg);
+void virNetClientStreamSetClosed(virNetClientStreamPtr st,
+ int closed);
+
bool virNetClientStreamMatches(virNetClientStreamPtr st,
virNetMessagePtr msg);
--
1.8.3.1