[libvirt] [PATCH 00/10] New internal infrastructure for RPC

This is an update of the series posted http://www.redhat.com/archives/libvir-list/2010-December/msg00616.html Since that time - Soooooooooo many bug fixes and much testing - Addition of mDNS APIs for the server - Integration of streams support in client The actual conversion of the remote driver + libvirtd has been removed from this series. I would like to get the infrastructure all reviewed & merged, then will repost the actual conversions.

This provides a new struct that contains a buffer for the RPC message header+payload, as well as a decoded copy of the message header. There is an API for applying an XDR encoding & decoding of the message headers and payloads. There are also APIs for maintaining a simple FIFO queue of message instances. Expected usage scenarios are: To send a message msg = virNetMessageNew() ...fill in msg->header fields... virNetMessageEncodeHeader(msg) ...look at msg->header fields to determine payload filter virNetMessageEncodePayload(msg, xdrfilter, data) ...send msg->bufferLength worth of data from buffer To receive a message msg = virNetMessageNew() ...read VIR_NET_MESSAGE_LEN_MAX of data into buffer virNetMessageDecodeLength(msg) ...read msg->bufferLength-msg->bufferOffset of data into buffer virNetMessageDecodeHeader(msg) ...look at msg->header fields to determine payload filter virNetMessageDecodePayload(msg, xdrfilter, data) ...run payload processor * src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnetmessage.c, src/rpc/virnetmessage.h: Internal message handling API.
--- po/POTFILES.in | 1 + src/Makefile.am | 1 + src/rpc/virnetmessage.c | 365 +++++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetmessage.h | 70 +++++++++ 4 files changed, 437 insertions(+), 0 deletions(-) create mode 100644 src/rpc/virnetmessage.c create mode 100644 src/rpc/virnetmessage.h diff --git a/po/POTFILES.in b/po/POTFILES.in index 1ed2765..65f4fc3 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -64,6 +64,7 @@ src/qemu/qemu_monitor_json.c src/qemu/qemu_monitor_text.c src/qemu/qemu_process.c src/remote/remote_driver.c +src/rpc/virnetmessage.c src/secret/secret_driver.c src/security/security_apparmor.c src/security/security_dac.c diff --git a/src/Makefile.am b/src/Makefile.am index fc31e5d..7b9cdd3 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1202,6 +1202,7 @@ EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt_net_rpc_la_SOURCES = \ + rpc/virnetmessage.h rpc/virnetmessage.c \ rpc/virnetprotocol.h rpc/virnetprotocol.c libvirt_net_rpc_la_CFLAGS = \ $(AM_CFLAGS) diff --git a/src/rpc/virnetmessage.c b/src/rpc/virnetmessage.c new file mode 100644 index 0000000..4c226d2 --- /dev/null +++ b/src/rpc/virnetmessage.c @@ -0,0 +1,365 @@ +/* + * virnetmessage.h: basic RPC message encoding/decoding + * + * Copyright (C) 2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include <config.h> + +#include <stdlib.h> + +#include "virnetmessage.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +virNetMessagePtr virNetMessageNew(void) +{ + virNetMessagePtr msg; + + if (VIR_ALLOC(msg) < 0) { + virReportOOMError(); + return NULL; + } + + VIR_DEBUG("msg=%p", msg); + + return msg; +} + +void virNetMessageFree(virNetMessagePtr msg) +{ + if (!msg) + return; + + VIR_DEBUG("msg=%p", msg); + + VIR_FREE(msg); +} + +void virNetMessageQueuePush(virNetMessagePtr *queue, virNetMessagePtr msg) +{ + virNetMessagePtr tmp = *queue; + + if (tmp) { + while (tmp->next) + tmp = tmp->next; + tmp->next = msg; + } else { + *queue = msg; + } +} + + +virNetMessagePtr virNetMessageQueueServe(virNetMessagePtr *queue) +{ + virNetMessagePtr tmp = *queue; + + if (tmp) { + *queue = tmp->next; + tmp->next = NULL; + } + + return tmp; +} + + +int virNetMessageDecodeLength(virNetMessagePtr msg) +{ + XDR xdr; + unsigned int len; + int ret = -1; + + xdrmem_create(&xdr, msg->buffer, + msg->bufferLength, XDR_DECODE); + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message length")); + goto cleanup; + } + msg->bufferOffset = xdr_getpos(&xdr); + + if (len < VIR_NET_MESSAGE_LEN_MAX) { + virNetError(VIR_ERR_RPC, "%s", + _("packet received from server too small")); + goto cleanup; + } + + /* Length includes length word - adjust to real length to read.
*/ + len -= VIR_NET_MESSAGE_LEN_MAX; + + if (len > VIR_NET_MESSAGE_MAX) { + virNetError(VIR_ERR_RPC, "%s", + _("packet received from server too large")); + goto cleanup; + } + + /* Extend our declared buffer length and carry + on reading the header + payload */ + msg->bufferLength += len; + + VIR_DEBUG("Got length, now need %zu total (%u more)", + msg->bufferLength, len); + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +/* + * @msg: the complete incoming message, whose header to decode + * + * Decodes the header part of the message, but does not + * validate the decoded fields in the header. It expects + * bufferLength to refer to length of the data packet. Upon + * return bufferOffset will refer to the amount of the packet + * consumed by decoding of the header. + * + * returns 0 if successfully decoded, -1 upon fatal error + */ +int virNetMessageDecodeHeader(virNetMessagePtr msg) +{ + XDR xdr; + int ret = -1; + + msg->bufferOffset = VIR_NET_MESSAGE_LEN_MAX; + + /* Parse the header. */ + xdrmem_create(&xdr, + msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, + XDR_DECODE); + + if (!xdr_virNetMessageHeader(&xdr, &msg->header)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message header")); + goto cleanup; + } + + msg->bufferOffset += xdr_getpos(&xdr); + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +/* + * @msg: the outgoing message, whose header to encode + * + * Encodes the length word and header of the message, setting the + * message offset ready to encode the payload. Leaves space + * for the length field later.
Upon return bufferLength will + * refer to the total available space for message, while + * bufferOffset will refer to current space used by header + * + * returns 0 if successfully encoded, -1 upon fatal error + */ +int virNetMessageEncodeHeader(virNetMessagePtr msg) +{ + XDR xdr; + int ret = -1; + unsigned int len = 0; + + msg->bufferLength = sizeof(msg->buffer); + msg->bufferOffset = 0; + + /* Format the header. */ + xdrmem_create(&xdr, + msg->buffer, + msg->bufferLength, + XDR_ENCODE); + + /* The real value is filled in shortly */ + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto cleanup; + } + + if (!xdr_virNetMessageHeader(&xdr, &msg->header)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message header")); + goto cleanup; + } + + len = xdr_getpos(&xdr); + xdr_setpos(&xdr, 0); + + /* Fill in current length - may be re-written later + * if a payload is added + */ + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to re-encode message length")); + goto cleanup; + } + + msg->bufferOffset += len; + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) +{ + XDR xdr; + unsigned int msglen; + + /* Serialise payload of the message. This assumes that + * virNetMessageEncodeHeader has already been run, so + * just appends to that data */ + xdrmem_create(&xdr, msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, XDR_ENCODE); + + if (!(*filter)(&xdr, data)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message payload")); + goto error; + } + + /* Advance bufferOffset past the newly encoded payload bytes. */ + msg->bufferOffset += xdr_getpos(&xdr); + xdr_destroy(&xdr); + + /* Re-encode the length word.
*/ + VIR_DEBUG("Encode length as %zu", msg->bufferOffset); + xdrmem_create(&xdr, msg->buffer, VIR_NET_MESSAGE_HEADER_XDR_LEN, XDR_ENCODE); + msglen = msg->bufferOffset; + if (!xdr_u_int(&xdr, &msglen)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto error; + } + xdr_destroy(&xdr); + + msg->bufferLength = msg->bufferOffset; + msg->bufferOffset = 0; + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + + +int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) +{ + XDR xdr; + + /* Deserialise payload of the message. This assumes that + * virNetMessageDecodeHeader has already been run, so + * just start from after that data */ + xdrmem_create(&xdr, msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, XDR_DECODE); + + if (!(*filter)(&xdr, data)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message payload")); + goto error; + } + + /* Record the consumed payload in bufferOffset (bufferLength is the packet size and must not grow). */ + msg->bufferOffset += xdr_getpos(&xdr); + xdr_destroy(&xdr); + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + + +int virNetMessageEncodePayloadRaw(virNetMessagePtr msg, + const char *data, + size_t len) +{ + XDR xdr; + unsigned int msglen; + + if ((msg->bufferLength - msg->bufferOffset) < len) { + virNetError(VIR_ERR_RPC, + _("Stream data too long to send (%zu bytes needed, %zu bytes available)"), + len, (msg->bufferLength - msg->bufferOffset)); + return -1; + } + + memcpy(msg->buffer + msg->bufferOffset, data, len); + msg->bufferOffset += len; + + /* Re-encode the length word.
*/ + VIR_DEBUG("Encode length as %zu", msg->bufferOffset); + xdrmem_create(&xdr, msg->buffer, VIR_NET_MESSAGE_HEADER_XDR_LEN, XDR_ENCODE); + msglen = msg->bufferOffset; + if (!xdr_u_int(&xdr, &msglen)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto error; + } + xdr_destroy(&xdr); + + msg->bufferLength = msg->bufferOffset; + msg->bufferOffset = 0; + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + + +void virNetMessageSaveError(virNetMessageErrorPtr rerr) +{ + /* This func may be called several times & the first + * error is the one we want because we don't want + * cleanup code overwriting the first one. + */ + if (rerr->code != VIR_ERR_OK) + return; + + virErrorPtr verr = virGetLastError(); + if (verr) { + rerr->code = verr->code; + rerr->domain = verr->domain; + rerr->message = verr->message ? malloc(sizeof(char*)) : NULL; + if (rerr->message) *rerr->message = strdup(verr->message); + rerr->level = verr->level; + rerr->str1 = verr->str1 ? malloc(sizeof(char*)) : NULL; + if (rerr->str1) *rerr->str1 = strdup(verr->str1); + rerr->str2 = verr->str2 ? malloc(sizeof(char*)) : NULL; + if (rerr->str2) *rerr->str2 = strdup(verr->str2); + rerr->str3 = verr->str3 ? malloc(sizeof(char*)) : NULL; + if (rerr->str3) *rerr->str3 = strdup(verr->str3); + rerr->int1 = verr->int1; + rerr->int2 = verr->int2; + } else { + rerr->code = VIR_ERR_INTERNAL_ERROR; + rerr->domain = VIR_FROM_RPC; + rerr->message = malloc(sizeof(char*)); + if (rerr->message) *rerr->message = strdup(_("Library function returned error but did not set virError")); + rerr->level = VIR_ERR_ERROR; + } +} diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h new file mode 100644 index 0000000..9a92c0b --- /dev/null +++ b/src/rpc/virnetmessage.h @@ -0,0 +1,70 @@ +/* + * virnetmessage.h: basic RPC message encoding/decoding + * + * Copyright (C) 2010 Red Hat, Inc.
+ * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef __VIR_NET_MESSAGE_H__ +# define __VIR_NET_MESSAGE_H__ + +# include <stdbool.h> + +# include "virnetprotocol.h" + +typedef struct virNetMessageHeader *virNetMessageHeaderPtr; +typedef struct virNetMessageError *virNetMessageErrorPtr; + +typedef struct _virNetMessage virNetMessage; +typedef virNetMessage *virNetMessagePtr; + +struct _virNetMessage { + char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX]; + size_t bufferLength; + size_t bufferOffset; + + virNetMessageHeader header; + + virNetMessagePtr next; +}; + + +virNetMessagePtr virNetMessageNew(void); + +void virNetMessageFree(virNetMessagePtr msg); + +virNetMessagePtr virNetMessageQueueServe(virNetMessagePtr *queue); +void virNetMessageQueuePush(virNetMessagePtr *queue, + virNetMessagePtr msg); + +int virNetMessageEncodeHeader(virNetMessagePtr msg); +int virNetMessageDecodeLength(virNetMessagePtr msg); +int virNetMessageDecodeHeader(virNetMessagePtr msg); + +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); +int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); + +int virNetMessageEncodePayloadRaw(virNetMessagePtr msg, + const char *buf, + size_t len); + +void
virNetMessageSaveError(virNetMessageErrorPtr rerr); + +#endif /* __VIR_NET_MESSAGE_H__ */ -- 1.7.4

On 03/15/2011 11:51 AM, Daniel P. Berrange wrote:
This provides a new struct that contains a buffer for the RPC message header+payload, as well as a decoded copy of the message header. There is an API for applying a XDR encoding & decoding of the message headers and payloads. There are also APIs for maintaining a simple FIFO queue of message instances.
Expected usage scenarios are:
To send a message
msg = virNetMessageNew()
...fill in msg->header fields.. virNetMessageEncodeHeader(msg) ...loook at msg->header fields to determine payload filter virNetMessageEncodePayload(msg, xdrfilter, data) ...send msg->bufferLength worth of data from buffer
To receive a message
msg = virNetMessageNew() ...read VIR_NET_MESSAGE_LEN_MAX of data into buffer virNetMessageDecodeLength(msg) ...read msg->bufferLength-msg->bufferOffset of data into buffer virNetMessageDecodeHeader(msg) ...look at msg->header fields to determine payload filter virNetMessageDecodePayload(msg, xdrfilter, data) ...run payload processor
+++ b/src/Makefile.am @@ -1202,6 +1202,7 @@ EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) noinst_LTLIBRARIES += libvirt-net-rpc.la
libvirt_net_rpc_la_SOURCES = \ + rpc/virnetmessage.h rpc/virnetmessage.c \ rpc/virnetprotocol.h rpc/virnetprotocol.c
Huh? There's no src/rpc in current libvirt.git. This looks like 2/15 in the v1 posting. Did you forget to submit the original 1/15 as a prerequisite patch? http://www.redhat.com/archives/libvir-list/2010-December/msg00617.html
libvirt_net_rpc_la_CFLAGS = \ $(AM_CFLAGS) diff --git a/src/rpc/virnetmessage.c b/src/rpc/virnetmessage.c new file mode 100644 index 0000000..4c226d2 --- /dev/null +++ b/src/rpc/virnetmessage.c @@ -0,0 +1,365 @@ +/* + * virnetmessage.h: basic RPC message encoding/decoding + * + * Copyright (C) 2010 Red Hat, Inc.
Welcome to 2011. :)
+/* + * @msg: the outgoing message, whose header to encode + * + * Encodes the length word and header of the message, setting the
Hmm, you still missed my spacing comment from http://www.redhat.com/archives/libvir-list/2010-December/msg00657.html s/the  message/the message/ (i.e. there is a doubled space between "the" and "message" that should be a single space)
diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h new file mode 100644 index 0000000..9a92c0b --- /dev/null +++ b/src/rpc/virnetmessage.h @@ -0,0 +1,70 @@ +/* + * virnetmessage.h: basic RPC message encoding/decoding + * + * Copyright (C) 2010 Red Hat, Inc.
2011
+ +struct _virNetMessage { + char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX];
Is it worth a comment warning that this struct should never be stack-allocated?
+void virNetMessageFree(virNetMessagePtr msg);
cfg.mk should list this as a free-like function.
+ +virNetMessagePtr virNetMessageQueueServe(virNetMessagePtr *queue);
It's worth adding attributes: ATTRIBUTE_NONNULL(1)
+void virNetMessageQueuePush(virNetMessagePtr *queue, + virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2)
+ +int virNetMessageEncodeHeader(virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK
+int virNetMessageDecodeLength(virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK
+int virNetMessageDecodeHeader(virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK
+ +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK
+int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK
+ +int virNetMessageEncodePayloadRaw(virNetMessagePtr msg, + const char *buf, + size_t len); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK
+ +void virNetMessageSaveError(virNetMessageErrorPtr rerr); ATTRIBUTE_NONNULL(1)
-- Eric Blake eblake@redhat.com +1-801-349-2682 Libvirt virtualization library http://libvirt.org

On Tue, Mar 15, 2011 at 01:34:53PM -0600, Eric Blake wrote:
On 03/15/2011 11:51 AM, Daniel P. Berrange wrote:
This provides a new struct that contains a buffer for the RPC message header+payload, as well as a decoded copy of the message header. There is an API for applying a XDR encoding & decoding of the message headers and payloads. There are also APIs for maintaining a simple FIFO queue of message instances.
Expected usage scenarios are:
To send a message
msg = virNetMessageNew()
...fill in msg->header fields.. virNetMessageEncodeHeader(msg) ...loook at msg->header fields to determine payload filter virNetMessageEncodePayload(msg, xdrfilter, data) ...send msg->bufferLength worth of data from buffer
To receive a message
msg = virNetMessageNew() ...read VIR_NET_MESSAGE_LEN_MAX of data into buffer virNetMessageDecodeLength(msg) ...read msg->bufferLength-msg->bufferOffset of data into buffer virNetMessageDecodeHeader(msg) ...look at msg->header fields to determine payload filter virNetMessageDecodePayload(msg, xdrfilter, data) ...run payload processor
+++ b/src/Makefile.am @@ -1202,6 +1202,7 @@ EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) noinst_LTLIBRARIES += libvirt-net-rpc.la
libvirt_net_rpc_la_SOURCES = \ + rpc/virnetmessage.h rpc/virnetmessage.c \ rpc/virnetprotocol.h rpc/virnetprotocol.c
Huh? There's no src/rpc in current libvirt.git. This looks like 2/15 in the v1 posting. Did you forget to submit the original 1/15 as a prerequisite patch? http://www.redhat.com/archives/libvir-list/2010-December/msg00617.html
Yes, picked the wrong hash when sending the series.
+ +struct _virNetMessage { + char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX];
Is it worth a comment warning that this struct should never be stack-allocated?
Added
+void virNetMessageFree(virNetMessagePtr msg);
cfg.mk should list this as a free-like function.
Added, and several more in later patches
+ +virNetMessagePtr virNetMessageQueueServe(virNetMessagePtr *queue);
It's worth adding attributes:
ATTRIBUTE_NONNULL(1)
+void virNetMessageQueuePush(virNetMessagePtr *queue, + virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2)
+ +int virNetMessageEncodeHeader(virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK
+int virNetMessageDecodeLength(virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK
+int virNetMessageDecodeHeader(virNetMessagePtr msg); ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK
+ +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK
+int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK
+ +int virNetMessageEncodePayloadRaw(virNetMessagePtr msg, + const char *buf, + size_t len); ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK
+ +void virNetMessageSaveError(virNetMessageErrorPtr rerr); ATTRIBUTE_NONNULL(1)
Added all those too. Daniel -- |: http://berrange.com -o- http://www.flickr.com/photos/dberrange/ :| |: http://libvirt.org -o- http://virt-manager.org :| |: http://autobuild.org -o- http://search.cpan.org/~danberr/ :| |: http://entangle-photo.org -o- http://live.gnome.org/gtk-vnc :|

Introduces a simple wrapper around the raw POSIX sockets APIs and name resolution APIs. Allows for easy creation of client and server sockets with correct usage of name resolution APIs for protocol agnostic socket setup. It can listen for UNIX and TCP stream sockets. It can connect to UNIX, TCP streams directly, or indirectly to UNIX sockets via an SSH tunnel or external command * src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Generic sockets APIs --- .x-sc_avoid_write | 1 + configure.ac | 2 +- po/POTFILES.in | 1 + src/Makefile.am | 3 +- src/rpc/virnetsocket.c | 813 ++++++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetsocket.h | 107 +++++++ 6 files changed, 925 insertions(+), 2 deletions(-) create mode 100644 src/rpc/virnetsocket.c create mode 100644 src/rpc/virnetsocket.h diff --git a/.x-sc_avoid_write b/.x-sc_avoid_write index 0784984..5565713 100644 --- a/.x-sc_avoid_write +++ b/.x-sc_avoid_write @@ -1,6 +1,7 @@ ^src/libvirt\.c$ ^src/fdstream\.c$ ^src/qemu/qemu_monitor\.c$ +^src/rpc/virnetsocket\.c$ ^src/util/command\.c$ ^src/util/util\.c$ ^src/xen/xend_internal\.c$ diff --git a/configure.ac b/configure.ac index e2b2b24..49403dd 100644 --- a/configure.ac +++ b/configure.ac @@ -134,7 +134,7 @@ LIBS=$old_libs dnl Availability of various common headers (non-fatal if missing). 
AC_CHECK_HEADERS([pwd.h paths.h regex.h sys/syslimits.h sys/un.h \ sys/poll.h syslog.h mntent.h net/ethernet.h linux/magic.h \ - sys/un.h sys/syscall.h]) + sys/un.h sys/syscall.h netinet/tcp.h]) AC_CHECK_LIB([intl],[gettext],[]) diff --git a/po/POTFILES.in b/po/POTFILES.in index 65f4fc3..2ce3bba 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -65,6 +65,7 @@ src/qemu/qemu_monitor_text.c src/qemu/qemu_process.c src/remote/remote_driver.c src/rpc/virnetmessage.c +src/rpc/virnetsocket.c src/secret/secret_driver.c src/security/security_apparmor.c src/security/security_dac.c diff --git a/src/Makefile.am b/src/Makefile.am index 7b9cdd3..332b6ac 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1203,7 +1203,8 @@ noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ - rpc/virnetprotocol.h rpc/virnetprotocol.c + rpc/virnetprotocol.h rpc/virnetprotocol.c \ + rpc/virnetsocket.h rpc/virnetsocket.c libvirt_net_rpc_la_CFLAGS = \ $(AM_CFLAGS) libvirt_net_rpc_la_LDFLAGS = \ diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c new file mode 100644 index 0000000..a0eb431 --- /dev/null +++ b/src/rpc/virnetsocket.c @@ -0,0 +1,813 @@ +/* + * virnetsocket.h: generic network socket handling + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <sys/stat.h> +#include <sys/socket.h> +#include <unistd.h> +#include <sys/wait.h> + +#include "virnetsocket.h" +#include "util.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" +#include "files.h" +#include "event.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) \ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + + +struct _virNetSocket { + int fd; + int watch; + pid_t pid; + int errfd; + bool client; + virNetSocketIOFunc func; + void *opaque; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + char *localAddrStr; + char *remoteAddrStr; +}; + + +#ifndef WIN32 +static int virNetSocketForkDaemon(const char *binary) +{ + int ret; + virCommandPtr cmd = virCommandNewArgList(binary, + "--timeout=30", + NULL); + + virCommandAddEnvPassCommon(cmd); + virCommandClearCaps(cmd); + virCommandDaemonize(cmd); + ret = virCommandRun(cmd, NULL); + virCommandFree(cmd); + return ret; +} +#endif + + +static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr, + virSocketAddrPtr remoteAddr, + bool isClient, + int fd, int errfd, pid_t pid) +{ + virNetSocketPtr sock; + int no_slow_start = 1; + + VIR_DEBUG("localAddr=%p remoteAddr=%p fd=%d errfd=%d pid=%d", + localAddr, remoteAddr, + fd, errfd, pid); + + if (virSetCloseExec(fd) < 0 || + virSetNonBlock(fd) < 0) + return NULL; + + if (VIR_ALLOC(sock) < 0) { + virReportOOMError(); + return NULL; + } + + if (localAddr) + sock->localAddr = *localAddr; + if (remoteAddr) + sock->remoteAddr = *remoteAddr; + sock->fd = fd; + sock->errfd = errfd; + sock->pid = pid; + + /* Disable nagle for TCP sockets */ + if 
(sock->localAddr.data.sa.sa_family == AF_INET || + sock->localAddr.data.sa.sa_family == AF_INET6) + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, + &no_slow_start, + sizeof(no_slow_start)); + + + if (localAddr && + !(sock->localAddrStr = virSocketFormatAddrFull(localAddr, true, ";"))) + goto error; + + if (remoteAddr && + !(sock->remoteAddrStr = virSocketFormatAddrFull(remoteAddr, true, ";"))) + goto error; + + sock->client = isClient; + + VIR_DEBUG("sock=%p localAddrStr=%s remoteAddrStr=%s", + sock, NULLSTR(sock->localAddrStr), NULLSTR(sock->remoteAddrStr)); + + return sock; + +error: + sock->fd = sock->errfd = -1; /* Caller owns fd/errfd on failure */ + virNetSocketFree(sock); + return NULL; +} + + +int virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **retsocks, + size_t *nretsocks) +{ + virNetSocketPtr *socks = NULL; + size_t nsocks = 0; + struct addrinfo *ai = NULL; + struct addrinfo hints; + int fd = -1; + int i; + + *retsocks = NULL; + *nretsocks = 0; + + memset (&hints, 0, sizeof hints); + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + hints.ai_socktype = SOCK_STREAM; + + int e = getaddrinfo(nodename, service, &hints, &ai); + if (e != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to resolve address '%s' service '%s': %s"), + nodename, service, gai_strerror(e)); + return -1; + } + + struct addrinfo *runp = ai; + while (runp) { + virSocketAddr addr; + + memset(&addr, 0, sizeof(addr)); + + if ((fd = socket(runp->ai_family, runp->ai_socktype, + runp->ai_protocol)) < 0) { + virReportSystemError(errno, "%s", _("Unable to create socket")); + goto error; + } + + int opt = 1; + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt); + +#ifdef IPV6_V6ONLY + if (runp->ai_family == PF_INET6) { + int on = 1; + /* + * Normally on Linux an INET6 socket will bind to the INET4 + * address too. If getaddrinfo returns results with INET4 + * first though, this will result in INET6 binding failing. 
+ * We can trivially cope with multiple server sockets, so + * we force it to only listen on IPv6 + */ + setsockopt(fd, IPPROTO_IPV6,IPV6_V6ONLY, + (void*)&on, sizeof on); + } +#endif + + if (bind(fd, runp->ai_addr, runp->ai_addrlen) < 0) { + if (errno != EADDRINUSE) { + virReportSystemError(errno, "%s", _("Unable to bind to port")); + goto error; + } + VIR_FORCE_CLOSE(fd); + continue; + } + + addr.len = sizeof(addr.data); + if (getsockname(fd, &addr.data.sa, &addr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + + VIR_DEBUG("%p f=%d f=%d", &addr, runp->ai_family, addr.data.sa.sa_family); + + if (VIR_EXPAND_N(socks, nsocks, 1) < 0) { + virReportOOMError(); + goto error; + } + + if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, false, fd, -1, 0))) + goto error; + runp = runp->ai_next; + fd = -1; + } + + freeaddrinfo(ai); + + *retsocks = socks; + *nretsocks = nsocks; + return 0; + +error: + freeaddrinfo(ai); + for (i = 0 ; i < nsocks ; i++) + virNetSocketFree(socks[i]); + VIR_FREE(socks); + freeaddrinfo(ai); + VIR_FORCE_CLOSE(fd); + return -1; +} + + +#if HAVE_SYS_UN_H +int virNetSocketNewListenUNIX(const char *path, + mode_t mask, + gid_t grp, + virNetSocketPtr *retsock) +{ + virSocketAddr addr; + mode_t oldmask; + int fd; + + *retsock = NULL; + + memset(&addr, 0, sizeof(addr)); + + addr.len = sizeof(addr.data.un); + + if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) { + virReportSystemError(errno, "%s", _("Failed to create socket")); + goto error; + } + + addr.data.un.sun_family = AF_UNIX; + if (virStrcpyStatic(addr.data.un.sun_path, path) == NULL) { + virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path); + goto error; + } + if (addr.data.un.sun_path[0] == '@') + addr.data.un.sun_path[0] = '\0'; + else + unlink(addr.data.un.sun_path); + + oldmask = umask(~mask); + + if (bind(fd, &addr.data.sa, addr.len) < 0) { + virReportSystemError(errno, + _("Failed to bind socket to '%s'"), + path); 
+ goto error; + } + umask(oldmask); + + /* chown() doesn't work for abstract sockets but we use them only + * if libvirtd runs unprivileged + */ + if (grp != 0 && chown(path, -1, grp)) { + virReportSystemError(errno, + _("Failed to change group ID of '%s' to %d"), + path, grp); + goto error; + } + + if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0))) + goto error; + + return 0; + +error: + if (path[0] != '@') + unlink(path); + VIR_FORCE_CLOSE(fd); + return -1; +} +#else +int virNetSocketNewListenUNIX(const char *path ATTRIBUTE_UNUSED, + mode_t mask ATTRIBUTE_UNUSED, + gid_t grp ATTRIBUTE_UNUSED, + virNetSocketPtr *retsock ATTRIBUTE_UNUSED) +{ + virReportSystemError(ENOSYS, "%s", + _("UNIX sockets are not supported on this platform")); + return -1; +} +#endif + + +int virNetSocketNewConnectTCP(const char *nodename, + const char *service, + virNetSocketPtr *retsock) +{ + struct addrinfo *ai = NULL; + struct addrinfo hints; + int fd = -1; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + struct addrinfo *runp; + int savedErrno = ENOENT; + + *retsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + memset(&hints, 0, sizeof hints); + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + hints.ai_socktype = SOCK_STREAM; + + int e = getaddrinfo(nodename, service, &hints, &ai); + if (e != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to resolve address '%s' service '%s': %s"), + nodename, service, gai_strerror (e)); + return -1; + } + + runp = ai; + while (runp) { + int opt = 1; + + if ((fd = socket(runp->ai_family, runp->ai_socktype, + runp->ai_protocol)) < 0) { + virReportSystemError(errno, "%s", _("Unable to create socket")); + goto error; + } + + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt); + + if (connect(fd, runp->ai_addr, runp->ai_addrlen) >= 0) + break; + + savedErrno = errno; + VIR_FORCE_CLOSE(fd); + runp = runp->ai_next; + } + + if (fd == -1) { + 
virReportSystemError(savedErrno, + _("unable to connect to server at '%s:%s'"), + nodename, service); + goto error; + } + + localAddr.len = sizeof(localAddr.data); + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + + remoteAddr.len = sizeof(remoteAddr.data); + if (getpeername(fd, &remoteAddr.data.sa, &remoteAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get remote socket name")); + goto error; + } + + if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0))) + goto error; + + freeaddrinfo(ai); + + return 0; + +error: + freeaddrinfo(ai); + VIR_FORCE_CLOSE(fd); + return -1; +} + + +#if HAVE_SYS_UN_H +int virNetSocketNewConnectUNIX(const char *path, + bool spawnDaemon, + const char *binary, + virNetSocketPtr *retsock) +{ + virSocketAddr localAddr; + virSocketAddr remoteAddr; + int fd; + int retries = 0; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.un); + + if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) { + virReportSystemError(errno, "%s", _("Failed to create socket")); + goto error; + } + + remoteAddr.data.un.sun_family = AF_UNIX; + if (virStrcpyStatic(remoteAddr.data.un.sun_path, path) == NULL) { + virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path); + goto error; + } + if (remoteAddr.data.un.sun_path[0] == '@') + remoteAddr.data.un.sun_path[0] = '\0'; + +retry: + if (connect(fd, &remoteAddr.data.sa, remoteAddr.len) < 0) { + if (errno == ECONNREFUSED && spawnDaemon && retries < 20) { + if (retries == 0 && + virNetSocketForkDaemon(binary) < 0) + goto error; + + retries++; + usleep(1000 * 100 * retries); + goto retry; + } + + virReportSystemError(errno, + _("Failed to connect socket to '%s'"), + path); + goto error; + } + + localAddr.len = sizeof(localAddr.data); + if (getsockname(fd, &localAddr.data.sa, 
&localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + + if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0))) + goto error; + + return 0; + +error: + VIR_FORCE_CLOSE(fd); + return -1; +} +#else +int virNetSocketNewConnectUNIX(const char *path ATTRIBUTE_UNUSED, + bool spawnDaemon ATTRIBUTE_UNUSED, + const char *binary ATTRIBUTE_UNUSED, + virNetSocketPtr *retsock ATTRIBUTE_UNUSED) +{ + virReportSystemError(ENOSYS, "%s", + _("UNIX sockets are not supported on this platform")); + return -1; +} +#endif + + +#ifndef WIN32 +int virNetSocketNewConnectCommand(virCommandPtr cmd, + virNetSocketPtr *retsock) +{ + pid_t pid = 0; + int sv[2]; + int errfd[2]; + + *retsock = NULL; + + /* Fork off the external process. Use socketpair to create a private + * (unnamed) Unix domain socket to the child process so we don't have + * to faff around with two file descriptors (a la 'pipe(2)'). + */ + if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + if (pipe(errfd) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + virCommandSetInputFD(cmd, sv[1]); + virCommandSetOutputFD(cmd, &sv[1]); + virCommandSetErrorFD(cmd, &errfd[1]); + + if (virCommandRunAsync(cmd, &pid) < 0) + goto error; + + /* Parent continues here. 
*/ + VIR_FORCE_CLOSE(sv[1]); + VIR_FORCE_CLOSE(errfd[1]); + + if (!(*retsock = virNetSocketNew(NULL, NULL, true, sv[0], errfd[0], pid))) + goto error; + + virCommandFree(cmd); + + return 0; + +error: + VIR_FORCE_CLOSE(sv[0]); + VIR_FORCE_CLOSE(sv[1]); + VIR_FORCE_CLOSE(errfd[0]); + VIR_FORCE_CLOSE(errfd[1]); + + if (pid > 0) { + kill(pid, SIGTERM); + if (virCommandWait(cmd, NULL) < 0) { + kill(pid, SIGKILL); + if (virCommandWait(cmd, NULL) < 0) { + VIR_WARN("Unable to wait for command %d", pid); + } + } + } + + virCommandFree(cmd); + + return -1; +} +#else +int virNetSocketNewConnectCommand(virCommandPtr cmd ATTRIBUTE_UNUSED, + virNetSocketPtr *retsock ATTRIBUTE_UNUSED) +{ + virReportSystemError(errno, "%s", + _("Tunnelling sockets not supported on this platform")); + return -1; +} +#endif + +int virNetSocketNewConnectSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path, + virNetSocketPtr *retsock) +{ + virCommandPtr cmd; + *retsock = NULL; + + cmd = virCommandNew(binary ? binary : "ssh"); + virCommandAddEnvPassCommon(cmd); + virCommandAddEnvPass(cmd, "SSH_AUTH_SOCK"); + virCommandAddEnvPass(cmd, "SSH_ASKPASS"); + virCommandClearCaps(cmd); + + if (service) + virCommandAddArgList(cmd, "-p", service, NULL); + if (username) + virCommandAddArgList(cmd, "-l", username, NULL); + if (noTTY) + virCommandAddArgList(cmd, "-T", "-o", "BatchMode=yes", + "-e", "none", NULL); + virCommandAddArgList(cmd, nodename, + netcat ? 
netcat : "nc", + "-U", path, NULL); + + return virNetSocketNewConnectCommand(cmd, retsock); +} + + +int virNetSocketNewConnectExternal(const char **cmdargv, + virNetSocketPtr *retsock) +{ + virCommandPtr cmd; + + *retsock = NULL; + + cmd = virCommandNewArgs(cmdargv); + virCommandAddEnvPassCommon(cmd); + virCommandClearCaps(cmd); + + return virNetSocketNewConnectCommand(cmd, retsock); +} + + +void virNetSocketFree(virNetSocketPtr sock) +{ + if (!sock) + return; + + VIR_DEBUG("sock=%p fd=%d", sock, sock->fd); + if (sock->watch > 0) { + virEventRemoveHandle(sock->watch); + sock->watch = -1; + } + +#ifdef HAVE_SYS_UN_H + /* If a server socket, then unlink UNIX path */ + if (!sock->client && + sock->localAddr.data.sa.sa_family == AF_UNIX && + sock->localAddr.data.un.sun_path[0] != '\0') + unlink(sock->localAddr.data.un.sun_path); +#endif + + VIR_FORCE_CLOSE(sock->fd); + VIR_FORCE_CLOSE(sock->errfd); + +#ifndef WIN32 + if (sock->pid > 0) { + pid_t reap; + kill(sock->pid, SIGTERM); + do { +retry: + reap = waitpid(sock->pid, NULL, 0); + if (reap == -1 && errno == EINTR) + goto retry; + } while (reap != -1 && reap != sock->pid); + } +#endif + + VIR_FREE(sock->localAddrStr); + VIR_FREE(sock->remoteAddrStr); + + VIR_FREE(sock); +} + + +int virNetSocketGetFD(virNetSocketPtr sock) +{ + return sock->fd; +} + + +bool virNetSocketIsLocal(virNetSocketPtr sock) +{ + if (sock->localAddr.data.sa.sa_family == AF_UNIX) + return true; + return false; +} + + +#ifdef SO_PEERCRED +int virNetSocketGetLocalIdentity(virNetSocketPtr sock, + uid_t *uid, + pid_t *pid) +{ + struct ucred cr; + unsigned int cr_len = sizeof (cr); + + if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) { + virReportSystemError(errno, "%s", + _("Failed to get client socket identity")); + return -1; + } + + *pid = cr.pid; + *uid = cr.uid; + return 0; +} +# else +int virNetSocketGetLocalIdentity(virNetSocketPtr sock ATTRIBUTE_UNUSED, + uid_t *uid ATTRIBUTE_UNUSED, + pid_t *pid ATTRIBUTE_UNUSED) +{ + /* 
XXX Many more OS support UNIX socket credentials we could port to. See dbus ....*/ + virReportSystemError(ENOSYS, "%s", + _("Client socket identity not available")); + return -1; +} +# endif + + +int virNetSocketSetBlocking(virNetSocketPtr sock, + bool blocking) +{ + return virSetBlocking(sock->fd, blocking); +} + + +const char *virNetSocketLocalAddrString(virNetSocketPtr sock) +{ + return sock->localAddrStr; +} + +const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) +{ + return sock->remoteAddrStr; +} + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) +{ + return read(sock->fd, buf, len); +} + +ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) +{ + return write(sock->fd, buf, len); +} + + +int virNetSocketListen(virNetSocketPtr sock) +{ + if (listen(sock->fd, 30) < 0) { + virReportSystemError(errno, "%s", _("Unable to listen on socket")); + return -1; + } + return 0; +} + +int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) +{ + int fd; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + + *clientsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.stor); + if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) { + if (errno == ECONNABORTED || + errno == EAGAIN) + return 0; + + virReportSystemError(errno, "%s", + _("Unable to accept client")); + return -1; + } + + localAddr.len = sizeof(localAddr.data); + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + VIR_FORCE_CLOSE(fd); + return -1; + } + + if (!(*clientsock = virNetSocketNew(&localAddr, + &remoteAddr, + true, + fd, -1, 0))) { + VIR_FORCE_CLOSE(fd); + return -1; + } + + return 0; +} + + +static void virNetSocketEventHandle(int fd ATTRIBUTE_UNUSED, + int watch ATTRIBUTE_UNUSED, + int events, + void *opaque) +{ + virNetSocketPtr 
sock = opaque; + + sock->func(sock, events, sock->opaque); +} + +int virNetSocketAddIOCallback(virNetSocketPtr sock, + int events, + virNetSocketIOFunc func, + void *opaque) +{ + if (sock->watch > 0) { + VIR_DEBUG("Watch already registered on socket %p", sock); + return -1; + } + + if ((sock->watch = virEventAddHandle(sock->fd, + events, + virNetSocketEventHandle, + sock, + NULL)) < 0) { + VIR_WARN("Failed to register watch on socket %p", sock); + return -1; + } + sock->func = func; + sock->opaque = opaque; + + return 0; +} + +void virNetSocketUpdateIOCallback(virNetSocketPtr sock, + int events) +{ + if (sock->watch <= 0) { + VIR_DEBUG("Watch not registered on socket %p", sock); + return; + } + + virEventUpdateHandle(sock->watch, events); +} + +void virNetSocketRemoveIOCallback(virNetSocketPtr sock) +{ + if (sock->watch <= 0) { + VIR_DEBUG("Watch not registered on socket %p", sock); + return; + } + + virEventRemoveHandle(sock->watch); + sock->watch = 0; +} diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h new file mode 100644 index 0000000..c33b2e1 --- /dev/null +++ b/src/rpc/virnetsocket.h @@ -0,0 +1,107 @@ +/* + * virnetsocket.h: generic network socket handling + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_SOCKET_H__ +# define __VIR_NET_SOCKET_H__ + +# include "network.h" +# include "command.h" + +typedef struct _virNetSocket virNetSocket; +typedef virNetSocket *virNetSocketPtr; + + +typedef void (*virNetSocketIOFunc)(virNetSocketPtr sock, + int events, + void *opaque); + + +int virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **addrs, + size_t *naddrs); + +int virNetSocketNewListenUNIX(const char *path, + mode_t mask, + gid_t grp, + virNetSocketPtr *addr); + +int virNetSocketNewConnectTCP(const char *nodename, + const char *service, + virNetSocketPtr *addr); + +int virNetSocketNewConnectUNIX(const char *path, + bool spawnDaemon, + const char *binary, + virNetSocketPtr *addr); + +int virNetSocketNewConnectCommand(virCommandPtr cmd, + virNetSocketPtr *retsock); + +int virNetSocketNewConnectSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path, + virNetSocketPtr *addr); + +int virNetSocketNewConnectExternal(const char **cmdargv, + virNetSocketPtr *addr); + +int virNetSocketGetFD(virNetSocketPtr sock); +bool virNetSocketIsLocal(virNetSocketPtr sock); + +int virNetSocketGetLocalIdentity(virNetSocketPtr sock, + uid_t *uid, + pid_t *pid); + +int virNetSocketSetBlocking(virNetSocketPtr sock, + bool blocking); + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); +ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); + +void virNetSocketFree(virNetSocketPtr sock); + +const char *virNetSocketLocalAddrString(virNetSocketPtr sock); +const char *virNetSocketRemoteAddrString(virNetSocketPtr sock); + +int 
virNetSocketListen(virNetSocketPtr sock); +int virNetSocketAccept(virNetSocketPtr sock, + virNetSocketPtr *clientsock); + +int virNetSocketAddIOCallback(virNetSocketPtr sock, + int events, + virNetSocketIOFunc func, + void *opaque); + +void virNetSocketUpdateIOCallback(virNetSocketPtr sock, + int events); + +void virNetSocketRemoveIOCallback(virNetSocketPtr sock); + + + +#endif /* __VIR_NET_SOCKET_H__ */ -- 1.7.4

On 03/15/2011 11:51 AM, Daniel P. Berrange wrote:
Introduces a simple wrapper around the raw POSIX socket and name-resolution APIs. It allows easy creation of client and server sockets, using the name-resolution APIs correctly for protocol-agnostic socket setup.
It can listen for UNIX and TCP stream sockets.
It can connect to UNIX and TCP stream sockets directly, or indirectly to UNIX sockets via an SSH tunnel or an external command.
* src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Generic sockets APIs --- .x-sc_avoid_write | 1 + configure.ac | 2 +- po/POTFILES.in | 1 + src/Makefile.am | 3 +- src/rpc/virnetsocket.c | 813 ++++++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetsocket.h | 107 +++++++ 6 files changed, 925 insertions(+), 2 deletions(-) create mode 100644 src/rpc/virnetsocket.c create mode 100644 src/rpc/virnetsocket.h
Looks like most (all?) of my earlier review comments were incorporated - no more nasty double-close bugs :)
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c new file mode 100644 index 0000000..a0eb431 --- /dev/null +++ b/src/rpc/virnetsocket.c @@ -0,0 +1,813 @@ +/* + * virnetsocket.h: generic network socket handling + * + * Copyright (C) 2006-2010 Red Hat, Inc.
Add 2011.
+#ifndef WIN32 +static int virNetSocketForkDaemon(const char *binary) +{ + int ret; + virCommandPtr cmd = virCommandNewArgList(binary, + "--timeout=30", + NULL); + + virCommandAddEnvPassCommon(cmd); + virCommandClearCaps(cmd); + virCommandDaemonize(cmd); + ret = virCommandRun(cmd, NULL); + virCommandFree(cmd); + return ret; +} +#endif
Does this need an #else stub for mingw compilation, or is it only ever called from code already excluded on mingw?
+ + +static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr, + virSocketAddrPtr remoteAddr, + bool isClient, + int fd, int errfd, pid_t pid) +{ + virNetSocketPtr sock; + int no_slow_start = 1; + + VIR_DEBUG("localAddr=%p remoteAddr=%p fd=%d errfd=%d pid=%d", + localAddr, remoteAddr, + fd, errfd, pid); + + if (virSetCloseExec(fd) < 0 || + virSetNonBlock(fd) < 0) + return NULL;
No error message? The virSet* functions are intentionally silent on error, but this helper function should probably always issue an error on all failure paths....
+ + if (VIR_ALLOC(sock) < 0) { + virReportOOMError(); + return NULL;
given that it already did so on this path, and that the caller can't tell by a NULL return which error happened.
+ } + + if (localAddr) + sock->localAddr = *localAddr; + if (remoteAddr) + sock->remoteAddr = *remoteAddr; + sock->fd = fd; + sock->errfd = errfd; + sock->pid = pid; + + /* Disable nagle for TCP sockets */ + if (sock->localAddr.data.sa.sa_family == AF_INET || + sock->localAddr.data.sa.sa_family == AF_INET6) + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, + &no_slow_start, + sizeof(no_slow_start));
We don't care if setsockopt failed?
+ +int virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **retsocks, + size_t *nretsocks) +{ + virNetSocketPtr *socks = NULL; + size_t nsocks = 0; + struct addrinfo *ai = NULL; + struct addrinfo hints; + int fd = -1; + int i; + + *retsocks = NULL; + *nretsocks = 0; + + memset (&hints, 0, sizeof hints);
My earlier review comment about ' (' vs. '(' consistency in function calls still hasn't been addressed: http://www.redhat.com/archives/libvir-list/2010-December/msg00675.html
+ int opt = 1; + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt);
Again, no checks for setsockopt failures?
+error: + freeaddrinfo(ai); + for (i = 0 ; i < nsocks ; i++) + virNetSocketFree(socks[i]); + VIR_FREE(socks); + freeaddrinfo(ai);
Ouch - freeaddrinfo() is called twice on the error path - that is bound to segfault.
+ + oldmask = umask(~mask); + + if (bind(fd, &addr.data.sa, addr.len) < 0) { + virReportSystemError(errno, + _("Failed to bind socket to '%s'"), + path); + goto error; + } + umask(oldmask);
It's a shame that umask() is process-wide. This introduces a race window to other threads. Is this a case where we need another virFileOpenAs helper method, which forks, does the umask and bind in the child, then passes the fd back to the parent? But that's a question for another day, and doesn't affect the validity of this patch.
+ + +#ifndef WIN32 +int virNetSocketNewConnectCommand(virCommandPtr cmd, + virNetSocketPtr *retsock) +{ + pid_t pid = 0; + int sv[2]; + int errfd[2]; + + *retsock = NULL; + + /* Fork off the external process. Use socketpair to create a private + * (unnamed) Unix domain socket to the child process so we don't have + * to faff around with two file descriptors (a la 'pipe(2)'). + */ + if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + if (pipe(errfd) < 0) {
Should we set the parent's half of sv and errfd to cloexec?
+error: + VIR_FORCE_CLOSE(sv[0]); + VIR_FORCE_CLOSE(sv[1]); + VIR_FORCE_CLOSE(errfd[0]); + VIR_FORCE_CLOSE(errfd[1]); + + if (pid > 0) { + kill(pid, SIGTERM); + if (virCommandWait(cmd, NULL) < 0) { + kill(pid, SIGKILL); + if (virCommandWait(cmd, NULL) < 0) { + VIR_WARN("Unable to wait for command %d", pid);
Hmm, I really ought to write virCommandKill to make this idiom easier (my virFileOpenAs patch can also use it).
+int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) +{ + int fd; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + + *clientsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.stor); + if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) { + if (errno == ECONNABORTED || + errno == EAGAIN) + return 0;
As written, this function returns 0 for both retry and success, and -1 for all other failure; it is up to the caller to check whether *clientsock is NULL to know if a retry is needed. Should it return 0 for success and 1 for retry, to make it easier to use?
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h new file mode 100644 index 0000000..c33b2e1 --- /dev/null +++ b/src/rpc/virnetsocket.h @@ -0,0 +1,107 @@ +/* + * virnetsocket.h: generic network socket handling + * + * Copyright (C) 2006-2010 Red Hat, Inc.
2011 No change to src/libvirt_private.syms to list all these new functions? -- Eric Blake eblake@redhat.com +1-801-349-2682 Libvirt virtualization library http://libvirt.org

On Tue, Mar 15, 2011 at 03:23:38PM -0600, Eric Blake wrote:
On 03/15/2011 11:51 AM, Daniel P. Berrange wrote:
+#ifndef WIN32 +static int virNetSocketForkDaemon(const char *binary) +{ + int ret; + virCommandPtr cmd = virCommandNewArgList(binary, + "--timeout=30", + NULL); + + virCommandAddEnvPassCommon(cmd); + virCommandClearCaps(cmd); + virCommandDaemonize(cmd); + ret = virCommandRun(cmd, NULL); + virCommandFree(cmd); + return ret; +} +#endif
Does this need an #else stub for mingw compilation, or is it only ever called from code already excluded on mingw?
No, we simply never call it there, because the UNIX socket code that calls it is also compiled out on that platform.
+static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr, + virSocketAddrPtr remoteAddr, + bool isClient, + int fd, int errfd, pid_t pid) +{ + virNetSocketPtr sock; + int no_slow_start = 1; + + VIR_DEBUG("localAddr=%p remoteAddr=%p fd=%d errfd=%d pid=%d", + localAddr, remoteAddr, + fd, errfd, pid); + + if (virSetCloseExec(fd) < 0 || + virSetNonBlock(fd) < 0) + return NULL;
No error message? The virSet* functions are intentionally silent on error, but this helper function should probably always issue an error on all failure paths....
Yep, good point.
+ } + + if (localAddr) + sock->localAddr = *localAddr; + if (remoteAddr) + sock->remoteAddr = *remoteAddr; + sock->fd = fd; + sock->errfd = errfd; + sock->pid = pid; + + /* Disable nagle for TCP sockets */ + if (sock->localAddr.data.sa.sa_family == AF_INET || + sock->localAddr.data.sa.sa_family == AF_INET6) + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, + &no_slow_start, + sizeof(no_slow_start));
We don't care if setsockopt failed?
Yes, we should report errors.
+int virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **retsocks, + size_t *nretsocks) +{ + virNetSocketPtr *socks = NULL; + size_t nsocks = 0; + struct addrinfo *ai = NULL; + struct addrinfo hints; + int fd = -1; + int i; + + *retsocks = NULL; + *nretsocks = 0; + + memset (&hints, 0, sizeof hints);
My earlier review comment about ' (' vs. '(' consistency in function calls still hasn't been addressed: http://www.redhat.com/archives/libvir-list/2010-December/msg00675.html
I fixed quite a few of them, but missed some.
+ int opt = 1; + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt);
Again, no checks for setsockopt failures?
+error: + freeaddrinfo(ai); + for (i = 0 ; i < nsocks ; i++) + virNetSocketFree(socks[i]); + VIR_FREE(socks); + freeaddrinfo(ai);
Ouch - double freeaddrinfo - bound to segfault.
Oops.
+ oldmask = umask(~mask); + + if (bind(fd, &addr.data.sa, addr.len) < 0) { + virReportSystemError(errno, + _("Failed to bind socket to '%s'"), + path); + goto error; + } + umask(oldmask);
It's a shame that umask() is process-wide. This introduces a race window to other threads. Is this a case where we need another virFileOpenAs helper method, which forks, does the umask and bind in the child, then passes the fd back to the parent? But that's a question for another day, and doesn't affect the validity of this patch.
Fortunately, socket binding takes place during main daemon startup, when there is only one thread. In general, though, it would be nice to have a way to address this.
+ + +#ifndef WIN32 +int virNetSocketNewConnectCommand(virCommandPtr cmd, + virNetSocketPtr *retsock) +{ + pid_t pid = 0; + int sv[2]; + int errfd[2]; + + *retsock = NULL; + + /* Fork off the external process. Use socketpair to create a private + * (unnamed) Unix domain socket to the child process so we don't have + * to faff around with two file descriptors (a la 'pipe(2)'). + */ + if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + if (pipe(errfd) < 0) {
Should we set the parent's half of sv and errfd to cloexec?
Not sure it's really worth it, given the rest of our codebase.
+int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) +{ + int fd; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + + *clientsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.stor); + if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) { + if (errno == ECONNABORTED || + errno == EAGAIN) + return 0;
As written, this function returns 0 for both retry and success, and -1 for all other failure; it is up to the caller to check whether *clientsock is NULL to know if a retry is needed. Should it return 0 for success and 1 for retry, to make it easier to use?
This does not actually mean 'retry'. These "errors" occur if the client closed its TCP connection between the time of poll() and accept(), so there is simply no client to hand back and *clientsock is left NULL. So callers do something like this: if (virNetSocketAccept(sock, &client) < 0) return -1; if (!client) return 0; ...work with new client...; return 0;
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h new file mode 100644 index 0000000..c33b2e1 --- /dev/null +++ b/src/rpc/virnetsocket.h @@ -0,0 +1,107 @@ +/* + * virnetsocket.h: generic network socket handling + * + * Copyright (C) 2006-2010 Red Hat, Inc.
2011
No change to src/libvirt_private.syms to list all these new functions?
Nothing requires this (yet). Daniel -- |: http://berrange.com -o- http://www.flickr.com/photos/dberrange/ :| |: http://libvirt.org -o- http://virt-manager.org :| |: http://autobuild.org -o- http://search.cpan.org/~danberr/ :| |: http://entangle-photo.org -o- http://live.gnome.org/gtk-vnc :|

This provides two modules for handling TLS * virNetTLSContext provides the process-wide state, in particular all the x509 credentials, DH params and x509 whitelists * virNetTLSSession provides the per-connection state, ie the TLS session itself. The virNetTLSContext provides APIs for validating a TLS session's x509 credentials. The virNetTLSSession includes APIs for performing the initial TLS handshake and sending/recving encrypted data * src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnettlscontext.c, src/rpc/virnettlscontext.h: Generic TLS handling code --- configure.ac | 2 +- po/POTFILES.in | 1 + src/Makefile.am | 5 +- src/rpc/virnettlscontext.c | 892 ++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnettlscontext.h | 100 +++++ 5 files changed, 998 insertions(+), 2 deletions(-) create mode 100644 src/rpc/virnettlscontext.c create mode 100644 src/rpc/virnettlscontext.h diff --git a/configure.ac b/configure.ac index 49403dd..81bad91 100644 --- a/configure.ac +++ b/configure.ac @@ -134,7 +134,7 @@ LIBS=$old_libs dnl Availability of various common headers (non-fatal if missing). 
AC_CHECK_HEADERS([pwd.h paths.h regex.h sys/syslimits.h sys/un.h \ sys/poll.h syslog.h mntent.h net/ethernet.h linux/magic.h \ - sys/un.h sys/syscall.h netinet/tcp.h]) + sys/un.h sys/syscall.h netinet/tcp.h fnmatch.h]) AC_CHECK_LIB([intl],[gettext],[]) diff --git a/po/POTFILES.in b/po/POTFILES.in index 2ce3bba..30c69d1 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -66,6 +66,7 @@ src/qemu/qemu_process.c src/remote/remote_driver.c src/rpc/virnetmessage.c src/rpc/virnetsocket.c +src/rpc/virnettlscontext.c src/secret/secret_driver.c src/security/security_apparmor.c src/security/security_dac.c diff --git a/src/Makefile.am b/src/Makefile.am index 332b6ac..351bf2a 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1204,10 +1204,13 @@ noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ rpc/virnetprotocol.h rpc/virnetprotocol.c \ - rpc/virnetsocket.h rpc/virnetsocket.c + rpc/virnetsocket.h rpc/virnetsocket.c \ + rpc/virnettlscontext.h rpc/virnettlscontext.c libvirt_net_rpc_la_CFLAGS = \ + $(GNUTLS_CFLAGS) \ $(AM_CFLAGS) libvirt_net_rpc_la_LDFLAGS = \ + $(GNUTLS_LIBS) \ $(AM_LDFLAGS) \ $(CYGWIN_EXTRA_LDFLAGS) \ $(MINGW_EXTRA_LDFLAGS) diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c new file mode 100644 index 0000000..d45afb6 --- /dev/null +++ b/src/rpc/virnettlscontext.c @@ -0,0 +1,892 @@ +/* + * virnettlscontext.c: TLS encryption/x509 handling + * + * Copyright (C) 2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include <config.h> + +#include <unistd.h> +#ifdef HAVE_FNMATCH_H +# include <fnmatch.h> +#endif +#include <stdlib.h> + +#include <gnutls/gnutls.h> +#include <gnutls/x509.h> +#include "gnutls_1_0_compat.h" + +#include "virnettlscontext.h" + +#include "memory.h" +#include "virterror_internal.h" +#include "util.h" +#include "logging.h" +#include "configmake.h" + +#define DH_BITS 1024 + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define LIBVIRT_PKI_DIR SYSCONFDIR "/pki" +#define LIBVIRT_CACERT LIBVIRT_PKI_DIR "/CA/cacert.pem" +#define LIBVIRT_CACRL LIBVIRT_PKI_DIR "/CA/cacrl.pem" +#define LIBVIRT_CLIENTKEY LIBVIRT_PKI_DIR "/libvirt/private/clientkey.pem" +#define LIBVIRT_CLIENTCERT LIBVIRT_PKI_DIR "/libvirt/clientcert.pem" +#define LIBVIRT_SERVERKEY LIBVIRT_PKI_DIR "/libvirt/private/serverkey.pem" +#define LIBVIRT_SERVERCERT LIBVIRT_PKI_DIR "/libvirt/servercert.pem" + +#define virNetError(code, ...) 
\ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetTLSContext { + int refs; + + gnutls_certificate_credentials_t x509cred; + gnutls_dh_params_t dhParams; + + bool isServer; + bool requireValidCert; + const char *const*x509dnWhitelist; +}; + +struct _virNetTLSSession { + int refs; + + bool handshakeComplete; + + char *hostname; + gnutls_session_t session; + virNetTLSSessionWriteFunc writeFunc; + virNetTLSSessionReadFunc readFunc; + void *opaque; +}; + + +static int +virNetTLSContextCheckCertFile(const char *type, const char *file, bool allowMissing) +{ + if (!virFileExists(file)) { + if (allowMissing) + return 1; + + virReportSystemError(errno, + _("Cannot read %s '%s'"), + type, file); + return -1; + } + return 0; +} + + +static void virNetTLSLog(int level, const char *str) { + VIR_DEBUG("%d %s", level, str); +} + +static int virNetTLSContextLoadCredentials(virNetTLSContextPtr ctxt, + bool isServer, + const char *cacert, + const char *cacrl, + const char *cert, + const char *key) +{ + int ret = -1; + int err; + + if (cacert && cacert[0] != '\0') { + if (virNetTLSContextCheckCertFile("CA certificate", cacert, false) < 0) + goto cleanup; + + VIR_DEBUG("loading CA cert from %s", cacert); + err = gnutls_certificate_set_x509_trust_file(ctxt->x509cred, + cacert, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 CA certificate: %s: %s"), + cacert, gnutls_strerror (err)); + goto cleanup; + } + } + + if (cacrl && cacrl[0] != '\0') { + int rv; + if ((rv = virNetTLSContextCheckCertFile("CA revocation list", cacrl, true)) < 0) + goto cleanup; + + if (rv == 0) { + VIR_DEBUG("loading CRL from %s", cacrl); + err = gnutls_certificate_set_x509_crl_file(ctxt->x509cred, + cacrl, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 certificate revocation list: %s: %s"), + cacrl, gnutls_strerror(err)); + goto cleanup; + } 
+ } else { + VIR_DEBUG("Skipping non-existant CA CRL %s", cacrl); + } + } + + if (cert && cert[0] != '\0' && key && key[0] != '\0') { + int rv; + if ((rv = virNetTLSContextCheckCertFile("certificate", cert, !isServer)) < 0) + goto cleanup; + if (rv == 0 && + (rv = virNetTLSContextCheckCertFile("private key", key, !isServer)) < 0) + goto cleanup; + + if (rv == 0) { + VIR_DEBUG("loading cert and key from %s and %s", cert, key); + err = + gnutls_certificate_set_x509_key_file(ctxt->x509cred, + cert, key, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 key and certificate: %s, %s: %s"), + key, cert, gnutls_strerror (err)); + goto cleanup; + } + } else { + VIR_DEBUG("Skipping non-existant cert %s key %s on client", cert, key); + } + } + + ret = 0; + +cleanup: + return ret; +} + + +static virNetTLSContextPtr virNetTLSContextNew(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + const char *const*x509dnWhitelist, + bool requireValidCert, + bool isServer) +{ + virNetTLSContextPtr ctxt; + char *gnutlsdebug; + int err; + + VIR_DEBUG("cacert=%s cacrl=%s cert=%s key=%s requireValid=%d isServer=%d", + cacert, NULLSTR(cacrl), cert, key, requireValidCert, isServer); + + if (VIR_ALLOC(ctxt) < 0) { + virReportOOMError(); + return NULL; + } + + ctxt->refs = 1; + + /* Initialise GnuTLS. 
*/ + gnutls_global_init(); + + if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) { + int val; + if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0) + val = 10; + gnutls_global_set_log_level(val); + gnutls_global_set_log_function(virNetTLSLog); + VIR_DEBUG0("Enabled GNUTLS debug"); + } + + + err = gnutls_certificate_allocate_credentials(&ctxt->x509cred); + if (err) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to allocate x509 credentials: %s"), + gnutls_strerror (err)); + goto error; + } + + if (virNetTLSContextLoadCredentials(ctxt, isServer, cacert, cacrl, cert, key) < 0) + goto error; + + /* Generate Diffie Hellman parameters - for use with DHE + * kx algorithms. These should be discarded and regenerated + * once a day, once a week or once a month. Depending on the + * security requirements. + */ + if (isServer) { + err = gnutls_dh_params_init(&ctxt->dhParams); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to initialize diffie-hellman parameters: %s"), + gnutls_strerror (err)); + goto error; + } + err = gnutls_dh_params_generate2(ctxt->dhParams, DH_BITS); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to generate diffie-hellman parameters: %s"), + gnutls_strerror (err)); + goto error; + } + + gnutls_certificate_set_dh_params(ctxt->x509cred, + ctxt->dhParams); + } + + ctxt->requireValidCert = requireValidCert; + ctxt->x509dnWhitelist = x509dnWhitelist; + ctxt->isServer = isServer; + + return ctxt; + +error: + if (isServer) + gnutls_dh_params_deinit(ctxt->dhParams); + gnutls_certificate_free_credentials(ctxt->x509cred); + VIR_FREE(ctxt); + return NULL; +} + + +static int virNetTLSContextLocateCredentials(const char *pkipath, + bool tryUserPkiPath, + bool isServer, + char **cacert, + char **cacrl, + char **cert, + char **key) +{ + char *userdir = NULL; + char *user_pki_path = NULL; + + *cacert = NULL; + *cacrl = NULL; + *key = NULL; + *cert = NULL; + + VIR_DEBUG("pkipath=%s isServer=%d tryUserPkiPath=%d", + 
pkipath, isServer, tryUserPkiPath); + + /* Explicit path, then use that no matter whether the + * files actually exist there + */ + if (pkipath) { + VIR_DEBUG("Told to use TLS credentials in %s", pkipath); + if ((virAsprintf(cacert, "%s/%s", pkipath, + "cacert.pem")) < 0) + goto out_of_memory; + if ((virAsprintf(cacrl, "%s/%s", pkipath, + "cacrl.pem")) < 0) + goto out_of_memory; + if ((virAsprintf(key, "%s/%s", pkipath, + isServer ? "serverkey.pem" : "clientkey.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(cert, "%s/%s", pkipath, + isServer ? "servercert.pem" : "clientcert.pem")) < 0) + goto out_of_memory; + } else if (tryUserPkiPath) { + /* Check to see if $HOME/.pki contains at least one of the + * files and if so, use that + */ + userdir = virGetUserDirectory(getuid()); + + if (!userdir) + goto out_of_memory; + + if (virAsprintf(&user_pki_path, "%s/.pki/libvirt", userdir) < 0) + goto out_of_memory; + + VIR_DEBUG("Trying to find TLS user credentials in %s", user_pki_path); + + if ((virAsprintf(cacert, "%s/%s", user_pki_path, + "cacert.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(cacrl, "%s/%s", user_pki_path, + "cacrl.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(key, "%s/%s", user_pki_path, + isServer ? "serverkey.pem" : "clientkey.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(cert, "%s/%s", user_pki_path, + isServer ? "servercert.pem" : "clientcert.pem")) < 0) + goto out_of_memory; + + /* + * If one of CA cert can't be found then + * fallback to global default. 
Don't check + * for client cert/key, since they're optional + * in any case + */ + if (!virFileExists(*cacert)) { + VIR_FREE(*cacert); + VIR_FREE(*cacrl); + VIR_FREE(*key); + VIR_FREE(*cert); + } + } + + /* No explicit path, or user path didn't exist, so + * fallback to global defaults + */ + if (!*cacert) { + VIR_DEBUG0("Using default TLS credential paths"); + if (!(*cacert = strdup(LIBVIRT_CACERT))) + goto out_of_memory; + + if (!(*cacrl = strdup(LIBVIRT_CACRL))) + goto out_of_memory; + + if (!(*key = strdup(isServer ? LIBVIRT_SERVERKEY : LIBVIRT_CLIENTKEY))) + goto out_of_memory; + + if (!(*cert = strdup(isServer ? LIBVIRT_SERVERCERT : LIBVIRT_CLIENTCERT))) + goto out_of_memory; + } + + VIR_FREE(user_pki_path); + VIR_FREE(userdir); + + return 0; + +out_of_memory: + virReportOOMError(); + VIR_FREE(*cacert); + VIR_FREE(*cacrl); + VIR_FREE(*key); + VIR_FREE(*cert); + VIR_FREE(user_pki_path); + VIR_FREE(userdir); + return -1; +} + + +static virNetTLSContextPtr virNetTLSContextNewPath(const char *pkipath, + bool tryUserPkiPath, + const char *const*x509dnWhitelist, + bool requireValidCert, + bool isServer) +{ + char *cacert = NULL, *cacrl = NULL, *key = NULL, *cert = NULL; + virNetTLSContextPtr ctxt = NULL; + + if (virNetTLSContextLocateCredentials(pkipath, tryUserPkiPath, isServer, + &cacert, &cacrl, &key, &cert) < 0) + return NULL; + + ctxt = virNetTLSContextNew(cacert, cacrl, key, cert, + x509dnWhitelist, requireValidCert, isServer); + + VIR_FREE(cacert); + VIR_FREE(cacrl); + VIR_FREE(key); + VIR_FREE(cert); + + return ctxt; +} + +virNetTLSContextPtr virNetTLSContextNewServerPath(const char *pkipath, + bool tryUserPkiPath, + const char *const*x509dnWhitelist, + bool requireValidCert) +{ + return virNetTLSContextNewPath(pkipath, tryUserPkiPath, + x509dnWhitelist, requireValidCert, true); +} + +virNetTLSContextPtr virNetTLSContextNewClientPath(const char *pkipath, + bool tryUserPkiPath, + bool requireValidCert) +{ + return virNetTLSContextNewPath(pkipath, 
tryUserPkiPath, + NULL, requireValidCert, false); +} + + +virNetTLSContextPtr virNetTLSContextNewServer(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + const char *const*x509dnWhitelist, + bool requireValidCert) +{ + return virNetTLSContextNew(cacert, cacrl, key, cert, + x509dnWhitelist, requireValidCert, true); +} + + +virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + bool requireValidCert) +{ + return virNetTLSContextNew(cacert, cacrl, key, cert, + NULL, requireValidCert, false); +} + + +void virNetTLSContextRef(virNetTLSContextPtr ctxt) +{ + ctxt->refs++; +} + + +/* Check DN is on tls_allowed_dn_list. */ +static int +virNetTLSContextCheckDN(virNetTLSContextPtr ctxt, + const char *dname) +{ + const char *const*wildcards; + + /* If the list is not set, allow any DN. */ + wildcards = ctxt->x509dnWhitelist; + if (!wildcards) + return 1; + + while (*wildcards) { +#ifdef HAVE_FNMATCH_H + int ret = fnmatch (*wildcards, dname, 0); + if (ret == 0) /* Succesful match */ + return 1; + if (ret != FNM_NOMATCH) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Malformed TLS whitelist regular expression '%s'"), + *wildcards); + return -1; + } +#else + if (STREQ(*wildcards, dname)) + return 1; +#endif + + wildcards++; + } + + /* Log the client's DN for debugging */ + VIR_DEBUG(_("Failed whitelist check for client DN '%s'"), dname); + + /* This is the most common error: make it informative. */ + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Client's Distinguished Name is not on the list " + "of allowed clients (tls_allowed_dn_list). 
Use " + "'certtool -i --infile clientcert.pem' to view the" + "Distinguished Name field in the client certificate," + "or run this daemon with --verbose option.")); + return 0; +} + +static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess) +{ + int ret; + unsigned int status; + const gnutls_datum_t *certs; + unsigned int nCerts, i; + time_t now; + char name[256]; + size_t namesize = sizeof name; + + memset(name, 0, namesize); + + if ((ret = gnutls_certificate_verify_peers2(sess->session, &status)) < 0){ + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to verify TLS peer: %s"), + gnutls_strerror(ret)); + goto authdeny; + } + + if ((now = time(NULL)) == ((time_t)-1)) { + virReportSystemError(errno, "%s", + _("cannot get current time")); + goto authfail; + } + + if (status != 0) { + const char *reason = _("Invalid certificate"); + + if (status & GNUTLS_CERT_INVALID) + reason = _("The certificate is not trusted."); + + if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) + reason = _("The certificate hasn't got a known issuer."); + + if (status & GNUTLS_CERT_REVOKED) + reason = _("The certificate has been revoked."); + +#ifndef GNUTLS_1_0_COMPAT + if (status & GNUTLS_CERT_INSECURE_ALGORITHM) + reason = _("The certificate uses an insecure algorithm"); +#endif + + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Certificate failed validation: %s"), + reason); + goto authdeny; + } + + if (gnutls_certificate_type_get(sess->session) != GNUTLS_CRT_X509) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Only x509 certificates are supported")); + goto authdeny; + } + + if (!(certs = gnutls_certificate_get_peers(sess->session, &nCerts))) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The certificate has no peers")); + goto authdeny; + } + + for (i = 0; i < nCerts; i++) { + gnutls_x509_crt_t cert; + + if (gnutls_x509_crt_init (&cert) < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Unable to initialize certificate")); + goto authfail; + } + + if 
(gnutls_x509_crt_import(cert, &certs[i], GNUTLS_X509_FMT_DER) < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Unable to load certificate")); + gnutls_x509_crt_deinit(cert); + goto authfail; + } + + if (gnutls_x509_crt_get_expiration_time(cert) < now) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The client certificate has expired")); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (gnutls_x509_crt_get_activation_time(cert) > now) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The client certificate is not yet active")); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (i == 0) { + ret = gnutls_x509_crt_get_dn(cert, name, &namesize); + if (ret != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to get certificate distinguished name: %s"), + gnutls_strerror(ret)); + gnutls_x509_crt_deinit(cert); + goto authfail; + } + + if (virNetTLSContextCheckDN(ctxt, name) <= 0) { + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (sess->hostname && + !gnutls_x509_crt_check_hostname(cert, sess->hostname)) { + virNetError(VIR_ERR_RPC, + _("Certificate's owner does not match the hostname (%s)"), + sess->hostname); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + } + } + +#if 0 + PROBE(CLIENT_TLS_ALLOW, "fd=%d, name=%s", + virNetServerClientGetFD(client), name); +#endif + return 0; + +authdeny: +#if 0 + PROBE(CLIENT_TLS_DENY, "fd=%d, name=%s", + virNetServerClientGetFD(client), name); +#endif + return -1; + +authfail: +#if 0 + PROBE(CLIENT_TLS_FAIL, "fd=%d", + virNetServerClientGetFD(client)); +#endif + return -1; +} + +int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess) { + if (virNetTLSContextValidCertificate(ctxt, sess) < 0) { + if (ctxt->requireValidCert) { + virNetError(VIR_ERR_AUTH_FAILED, "%s", + _("Failed to verify peer's certificate")); + return -1; + } + VIR_INFO0(_("Ignoring bad certificate at user request")); + } + return 0; +} + +void virNetTLSContextFree(virNetTLSContextPtr ctxt) 
+{ + if (!ctxt) + return; + + ctxt->refs--; + if (ctxt->refs > 0) + return; + + gnutls_dh_params_deinit(ctxt->dhParams); + gnutls_certificate_free_credentials(ctxt->x509cred); + VIR_FREE(ctxt); +} + + + +static ssize_t +virNetTLSSessionPush(void *opaque, const void *buf, size_t len) +{ + virNetTLSSessionPtr sess = opaque; + if (!sess->writeFunc) { + VIR_WARN0("TLS session push with missing read function"); + errno = EIO; + return -1; + }; + + return sess->writeFunc(buf, len, sess->opaque); +} + + +static ssize_t +virNetTLSSessionPull(void *opaque, void *buf, size_t len) +{ + virNetTLSSessionPtr sess = opaque; + if (!sess->readFunc) { + VIR_WARN0("TLS session pull with missing read function"); + errno = EIO; + return -1; + }; + + return sess->readFunc(buf, len, sess->opaque); +} + + +virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, + const char *hostname) +{ + virNetTLSSessionPtr sess; + int err; + static const int cert_type_priority[] = { GNUTLS_CRT_X509, 0 }; + + VIR_DEBUG("ctxt=%p hostname=%s isServer=%d", ctxt, NULLSTR(hostname), ctxt->isServer); + + if (VIR_ALLOC(sess) < 0) { + virReportOOMError(); + return NULL; + } + + sess->refs = 1; + if (hostname && + !(sess->hostname = strdup(hostname))) { + virReportOOMError(); + goto error; + } + + if ((err = gnutls_init(&sess->session, + ctxt->isServer ? GNUTLS_SERVER : GNUTLS_CLIENT)) != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to initialize TLS session: %s"), + gnutls_strerror(err)); + goto error; + } + + /* avoid calling all the priority functions, since the defaults + * are adequate. 
+ */ + if ((err = gnutls_set_default_priority(sess->session)) != 0 || + (err = gnutls_certificate_type_set_priority(sess->session, + cert_type_priority))) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to set TLS session priority %s"), + gnutls_strerror(err)); + goto error; + } + + if ((err = gnutls_credentials_set(sess->session, + GNUTLS_CRD_CERTIFICATE, + ctxt->x509cred)) != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed set TLS x509 credentials: %s"), + gnutls_strerror(err)); + goto error; + } + + /* request client certificate if any. + */ + if (ctxt->isServer) { + gnutls_certificate_server_set_request(sess->session, GNUTLS_CERT_REQUEST); + + gnutls_dh_set_prime_bits(sess->session, DH_BITS); + } + + gnutls_transport_set_ptr(sess->session, sess); + gnutls_transport_set_push_function(sess->session, + virNetTLSSessionPush); + gnutls_transport_set_pull_function(sess->session, + virNetTLSSessionPull); + + return sess; + +error: + virNetTLSSessionFree(sess); + return NULL; +} + + +void virNetTLSSessionRef(virNetTLSSessionPtr sess) +{ + sess->refs++; +} + +void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, + virNetTLSSessionWriteFunc writeFunc, + virNetTLSSessionReadFunc readFunc, + void *opaque) +{ + sess->writeFunc = writeFunc; + sess->readFunc = readFunc; + sess->opaque = opaque; +} + + +ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, + const char *buf, size_t len) +{ + ssize_t ret; + ret = gnutls_record_send(sess->session, buf, len); + + if (ret >= 0) + return ret; + + switch (ret) { + case GNUTLS_E_AGAIN: + errno = EAGAIN; + break; + case GNUTLS_E_INTERRUPTED: + errno = EINTR; + break; + default: + errno = EIO; + break; + } + + return -1; +} + +ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, + char *buf, size_t len) +{ + ssize_t ret; + + ret = gnutls_record_recv(sess->session, buf, len); + + if (ret >= 0) + return ret; + + switch (ret) { + case GNUTLS_E_AGAIN: + errno = EAGAIN; + break; + case GNUTLS_E_INTERRUPTED: + errno = 
EINTR; + break; + default: + errno = EIO; + break; + } + + return -1; +} + +int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) +{ + VIR_DEBUG("sess=%p", sess); + int ret = gnutls_handshake(sess->session); + VIR_DEBUG("Ret=%d", ret); + if (ret == 0) { + sess->handshakeComplete = true; + VIR_DEBUG0("Handshake is complete"); + return 0; + } + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + return 1; + +#if 0 + PROBE(CLIENT_TLS_FAIL, "fd=%d", + virNetServerClientGetFD(client)); +#endif + + virNetError(VIR_ERR_AUTH_FAILED, + _("TLS handshake failed %s"), + gnutls_strerror (ret)); + return -1; +} + +virNetTLSSessionHandshakeStatus +virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess) +{ + if (sess->handshakeComplete) + return VIR_NET_TLS_HANDSHAKE_COMPLETE; + else if (gnutls_record_get_direction (sess->session) == 0) + return VIR_NET_TLS_HANDSHAKE_RECVING; + else + return VIR_NET_TLS_HANDSHAKE_SENDING; +} + +int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess) +{ + gnutls_cipher_algorithm_t cipher; + int ssf; + + cipher = gnutls_cipher_get(sess->session); + if (!(ssf = gnutls_cipher_get_key_size(cipher))) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("invalid cipher size for TLS session")); + return -1; + } + + return ssf; +} + + +void virNetTLSSessionFree(virNetTLSSessionPtr sess) +{ + if (!sess) + return; + + sess->refs--; + if (sess->refs > 0) + return; + + VIR_FREE(sess->hostname); + gnutls_deinit(sess->session); + VIR_FREE(sess); +} diff --git a/src/rpc/virnettlscontext.h b/src/rpc/virnettlscontext.h new file mode 100644 index 0000000..258ca4b --- /dev/null +++ b/src/rpc/virnettlscontext.h @@ -0,0 +1,100 @@ +/* + * virnettlscontext.h: TLS encryption/x509 handling + * + * Copyright (C) 2010 Red Hat, Inc. 
+ * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef __VIR_NET_TLS_CONTEXT_H__ +# define __VIR_NET_TLS_CONTEXT_H__ + +# include <stdbool.h> +# include <sys/types.h> + +typedef struct _virNetTLSContext virNetTLSContext; +typedef virNetTLSContext *virNetTLSContextPtr; + +typedef struct _virNetTLSSession virNetTLSSession; +typedef virNetTLSSession *virNetTLSSessionPtr; + + +virNetTLSContextPtr virNetTLSContextNewServerPath(const char *pkipath, + bool tryUserPkiPath, + const char *const*x509dnWhitelist, + bool requireValidCert); + +virNetTLSContextPtr virNetTLSContextNewClientPath(const char *pkipath, + bool tryUserPkiPath, + bool requireValidCert); + +virNetTLSContextPtr virNetTLSContextNewServer(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + const char *const*x509dnWhitelist, + bool requireValidCert); + +virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + bool requireValidCert); + +void virNetTLSContextRef(virNetTLSContextPtr ctxt); + +int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess); + +void virNetTLSContextFree(virNetTLSContextPtr ctxt); + + +typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len, 
+ void *opaque); +typedef ssize_t (*virNetTLSSessionReadFunc)(char *buf, size_t len, + void *opaque); + +virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, + const char *hostname); + +void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, + virNetTLSSessionWriteFunc writeFunc, + virNetTLSSessionReadFunc readFunc, + void *opaque); + +void virNetTLSSessionRef(virNetTLSSessionPtr sess); + +ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, + const char *buf, size_t len); +ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, + char *buf, size_t len); + +int virNetTLSSessionHandshake(virNetTLSSessionPtr sess); + +typedef enum { + VIR_NET_TLS_HANDSHAKE_COMPLETE, + VIR_NET_TLS_HANDSHAKE_SENDING, + VIR_NET_TLS_HANDSHAKE_RECVING, +} virNetTLSSessionHandshakeStatus; + +virNetTLSSessionHandshakeStatus +virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess); + +int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess); + +void virNetTLSSessionFree(virNetTLSSessionPtr sess); + + +#endif -- 1.7.4

On 03/15/2011 11:51 AM, Daniel P. Berrange wrote:
This provides two modules for handling TLS
* virNetTLSContext provides the process-wide state, in particular all the x509 credentials, DH params and x509 whitelists * virNetTLSSession provides the per-connection state, ie the TLS session itself.
The virNetTLSContext provides APIs for validating a TLS session's x509 credentials. The virNetTLSSession includes APIs for performing the initial TLS handshake and sending/recving encrypted data
* src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnettlscontext.c, src/rpc/virnettlscontext.h: Generic TLS handling code --- configure.ac | 2 +- po/POTFILES.in | 1 + src/Makefile.am | 5 +- src/rpc/virnettlscontext.c | 892 ++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnettlscontext.h | 100 +++++ 5 files changed, 998 insertions(+), 2 deletions(-) create mode 100644 src/rpc/virnettlscontext.c create mode 100644 src/rpc/virnettlscontext.h
No src/libvirt_private.syms entries?
diff --git a/configure.ac b/configure.ac index 49403dd..81bad91 100644 --- a/configure.ac +++ b/configure.ac @@ -134,7 +134,7 @@ LIBS=$old_libs dnl Availability of various common headers (non-fatal if missing). AC_CHECK_HEADERS([pwd.h paths.h regex.h sys/syslimits.h sys/un.h \ sys/poll.h syslog.h mntent.h net/ethernet.h linux/magic.h \ - sys/un.h sys/syscall.h netinet/tcp.h]) + sys/un.h sys/syscall.h netinet/tcp.h fnmatch.h])
Gnulib provides fnmatch. We shouldn't be adding this check, but modify bootstrap.conf instead.
+++ b/src/rpc/virnettlscontext.c @@ -0,0 +1,892 @@ +/* + * virnettlscontext.c: TLS encryption/x509 handling + * + * Copyright (C) 2010 Red Hat, Inc.
The copyright year should be 2011.
+#include <config.h> + +#include <unistd.h> +#ifdef HAVE_FNMATCH_H +# include <fnmatch.h> +#endif
This should be unconditional inclusion, thanks to gnulib.
+ +static int virNetTLSContextLoadCredentials(virNetTLSContextPtr ctxt, + bool isServer, + const char *cacert, + const char *cacrl, + const char *cert, + const char *key) +{ + int ret = -1; + int err; + + if (cacert && cacert[0] != '\0') { + if (virNetTLSContextCheckCertFile("CA certificate", cacert, false) < 0) + goto cleanup; + + VIR_DEBUG("loading CA cert from %s", cacert); + err = gnutls_certificate_set_x509_trust_file(ctxt->x509cred, + cacert, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 CA certificate: %s: %s"), + cacert, gnutls_strerror (err));
Consistency on ' (' vs. '(' for function calls.
+ } else { + VIR_DEBUG("Skipping non-existant cert %s key %s on client", cert, key);
s/existant/existent/
+ +/* Check DN is on tls_allowed_dn_list. */ +static int +virNetTLSContextCheckDN(virNetTLSContextPtr ctxt, + const char *dname) +{ + const char *const*wildcards; + + /* If the list is not set, allow any DN. */ + wildcards = ctxt->x509dnWhitelist; + if (!wildcards) + return 1; + + while (*wildcards) { +#ifdef HAVE_FNMATCH_H + int ret = fnmatch (*wildcards, dname, 0);
Use this unconditionally.
+ +#if 0 + PROBE(CLIENT_TLS_ALLOW, "fd=%d, name=%s", + virNetServerClientGetFD(client), name); +#endif + return 0;
Are these PROBE() statements worth keeping? Are they for debug, for systemtap probe points, or something else?
--- /dev/null +++ b/src/rpc/virnettlscontext.h @@ -0,0 +1,100 @@ +/* + * virnettlscontext.h: TLS encryption/x509 handling + * + * Copyright (C) 2010 Red Hat, Inc.
The copyright year should be 2011 here too.
+#ifndef __VIR_NET_TLS_CONTEXT_H__ +# define __VIR_NET_TLS_CONTEXT_H__ + +# include <stdbool.h>
Is this redundant, now that "internal.h" guarantees this and all .c files should be including "internal.h"? I don't see any other headers that include <stdbool.h> since commit 3541672.
+ +void virNetTLSSessionFree(virNetTLSSessionPtr sess);
Should cfg.mk list this as a free-like function? -- Eric Blake eblake@redhat.com +1-801-349-2682 Libvirt virtualization library http://libvirt.org

On Tue, Mar 15, 2011 at 04:34:33PM -0600, Eric Blake wrote:
On 03/15/2011 11:51 AM, Daniel P. Berrange wrote:
This provides two modules for handling TLS
* virNetTLSContext provides the process-wide state, in particular all the x509 credentials, DH params and x509 whitelists * virNetTLSSession provides the per-connection state, ie the TLS session itself.
The virNetTLSContext provides APIs for validating a TLS session's x509 credentials. The virNetTLSSession includes APIs for performing the initial TLS handshake and sending/recving encrypted data
* src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnettlscontext.c, src/rpc/virnettlscontext.h: Generic TLS handling code --- configure.ac | 2 +- po/POTFILES.in | 1 + src/Makefile.am | 5 +- src/rpc/virnettlscontext.c | 892 ++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnettlscontext.h | 100 +++++ 5 files changed, 998 insertions(+), 2 deletions(-) create mode 100644 src/rpc/virnettlscontext.c create mode 100644 src/rpc/virnettlscontext.h
No src/libvirt_private.syms entries?
diff --git a/configure.ac b/configure.ac index 49403dd..81bad91 100644 --- a/configure.ac +++ b/configure.ac @@ -134,7 +134,7 @@ LIBS=$old_libs dnl Availability of various common headers (non-fatal if missing). AC_CHECK_HEADERS([pwd.h paths.h regex.h sys/syslimits.h sys/un.h \ sys/poll.h syslog.h mntent.h net/ethernet.h linux/magic.h \ - sys/un.h sys/syscall.h netinet/tcp.h]) + sys/un.h sys/syscall.h netinet/tcp.h fnmatch.h])
Gnulib provides fnmatch. We shouldn't be adding this check, but modify bootstrap.conf instead.
Ah, I didn't know this. We already use fnmatch in libvirtd, but hadn't added the gnulib module for it.
+ +#if 0 + PROBE(CLIENT_TLS_ALLOW, "fd=%d, name=%s", + virNetServerClientGetFD(client), name); +#endif + return 0;
Are these PROBE() statements worth keeping? Are they for debug, for systemtap probe points, or something else?
They're something I need to fix before I finally convert libvirtd. I will address that in a follow-up patch once the generic code is committed.
+#ifndef __VIR_NET_TLS_CONTEXT_H__ +# define __VIR_NET_TLS_CONTEXT_H__ + +# include <stdbool.h>
Is this redundant, now that "internal.h" guarantees this and all .c files should be including "internal.h"? I don't see any other headers that include <stdbool.h> since commit 3541672.
Yes, I forgot to remove this one Daniel -- |: http://berrange.com -o- http://www.flickr.com/photos/dberrange/ :| |: http://libvirt.org -o- http://virt-manager.org :| |: http://autobuild.org -o- http://search.cpan.org/~danberr/ :| |: http://entangle-photo.org -o- http://live.gnome.org/gtk-vnc :|

This provides two modules for handling SASL: * virNetSASLContext provides the process-wide state, currently just a whitelist of usernames on the server and a one-time library init call * virNetSASLSession provides the per-connection state, i.e. the SASL session itself. This also includes APIs for providing data encryption/decryption once the session is established. * src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnetsaslcontext.c, src/rpc/virnetsaslcontext.h: Generic SASL handling code --- po/POTFILES.in | 1 + src/Makefile.am | 9 + src/rpc/virnetsaslcontext.c | 606 +++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetsaslcontext.h | 120 +++++++++ 4 files changed, 736 insertions(+), 0 deletions(-) create mode 100644 src/rpc/virnetsaslcontext.c create mode 100644 src/rpc/virnetsaslcontext.h diff --git a/po/POTFILES.in b/po/POTFILES.in index 30c69d1..53d63a8 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -65,6 +65,7 @@ src/qemu/qemu_monitor_text.c src/qemu/qemu_process.c src/remote/remote_driver.c src/rpc/virnetmessage.c +src/rpc/virnetsaslcontext.c src/rpc/virnetsocket.c src/rpc/virnettlscontext.c src/secret/secret_driver.c diff --git a/src/Makefile.am b/src/Makefile.am index 351bf2a..5d20d63 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1206,11 +1206,20 @@ libvirt_net_rpc_la_SOURCES = \ rpc/virnetprotocol.h rpc/virnetprotocol.c \ rpc/virnetsocket.h rpc/virnetsocket.c \ rpc/virnettlscontext.h rpc/virnettlscontext.c +if HAVE_SASL +libvirt_net_rpc_la_SOURCES += \ + rpc/virnetsaslcontext.h rpc/virnetsaslcontext.c +else +EXTRA_DIST += \ + rpc/virnetsaslcontext.h rpc/virnetsaslcontext.c +endif libvirt_net_rpc_la_CFLAGS = \ $(GNUTLS_CFLAGS) \ + $(SASL_CFLAGS) \ $(AM_CFLAGS) libvirt_net_rpc_la_LDFLAGS = \ $(GNUTLS_LIBS) \ + $(SASL_LIBS) \ $(AM_LDFLAGS) \ $(CYGWIN_EXTRA_LDFLAGS) \ $(MINGW_EXTRA_LDFLAGS) diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c new file mode 100644 index 0000000..c84cd6e --- /dev/null +++ 
b/src/rpc/virnetsaslcontext.c @@ -0,0 +1,606 @@ +/* + * virnetsaslcontext.c: SASL encryption/auth handling + * + * Copyright (C) 2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include <config.h> + +#ifdef HAVE_FNMATCH_H +# include <fnmatch.h> +#endif + +#include "virnetsaslcontext.h" +#include "virnetmessage.h" + +#include "virterror_internal.h" +#include "memory.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) \
\ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + + +struct _virNetSASLContext { + const char *const*usernameWhitelist; + int refs; +}; + +struct _virNetSASLSession { + sasl_conn_t *conn; + int refs; + size_t maxbufsize; +}; + + +virNetSASLContextPtr virNetSASLContextNewClient(void) +{ + virNetSASLContextPtr ctxt; + int err; + + err = sasl_client_init(NULL); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("failed to initialize SASL library: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return NULL; + } + + if (VIR_ALLOC(ctxt) < 0) { + virReportOOMError(); + return NULL; + } + + ctxt->refs = 1; + + return ctxt; +} + +virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitelist) +{ + virNetSASLContextPtr ctxt; + int err; + + err = sasl_server_init(NULL, "libvirt"); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("failed to initialize SASL library: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return NULL; + } + + if (VIR_ALLOC(ctxt) < 0) { + virReportOOMError(); + return NULL; + } + + ctxt->usernameWhitelist = usernameWhitelist; + ctxt->refs = 1; + + return ctxt; +} + +int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt, + const char *identity) +{ + const char *const*wildcards; + + /* Returns 1 if allowed, 0 if denied, -1 on a malformed pattern. If no username whitelist is set, allow any identity. */ + wildcards = ctxt->usernameWhitelist; + if (!wildcards) + return 1; /* No ACL, allow all */ + + while (*wildcards) { +#ifdef HAVE_FNMATCH_H + int ret = fnmatch(*wildcards, identity, 0); + if (ret == 0) /* Successful match */ + return 1; + if (ret != FNM_NOMATCH) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Malformed SASL whitelist regular expression '%s'"), + *wildcards); + return -1; + } +#else + if (STREQ(*wildcards, identity)) + return 1; +#endif + + wildcards++; + } + + /* Denied: no whitelist entry matched the identity */ + VIR_ERROR(_("SASL client %s not allowed in whitelist"), identity); + + /* This is the most common error: make it informative. 
*/ + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Client's username is not on the list of allowed clients")); + return 0; +} + + +void virNetSASLContextRef(virNetSASLContextPtr ctxt) +{ + ctxt->refs++; +} + +void virNetSASLContextFree(virNetSASLContextPtr ctxt) +{ + if (!ctxt) + return; + + ctxt->refs--; + if (ctxt->refs > 0) + return; + + VIR_FREE(ctxt); +} + +virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED, + const char *service, + const char *hostname, + const char *localAddr, + const char *remoteAddr, + const sasl_callback_t *cbs) +{ + virNetSASLSessionPtr sasl = NULL; + int err; + + if (VIR_ALLOC(sasl) < 0) { + virReportOOMError(); + goto cleanup; + } + + sasl->refs = 1; + /* Arbitrary size for amount of data we can encode in a single block */ + sasl->maxbufsize = 1 << 16; + + err = sasl_client_new(service, + hostname, + localAddr, + remoteAddr, + cbs, + SASL_SUCCESS_DATA, + &sasl->conn); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to create SASL client context: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + goto cleanup; + } + + return sasl; + +cleanup: + virNetSASLSessionFree(sasl); + return NULL; +} + +virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED, + const char *service, + const char *localAddr, + const char *remoteAddr) +{ + virNetSASLSessionPtr sasl = NULL; + int err; + + if (VIR_ALLOC(sasl) < 0) { + virReportOOMError(); + goto cleanup; + } + + sasl->refs = 1; + /* Arbitrary size for amount of data we can encode in a single block */ + sasl->maxbufsize = 1 << 16; + + err = sasl_server_new(service, + NULL, + NULL, + localAddr, + remoteAddr, + NULL, + SASL_SUCCESS_DATA, + &sasl->conn); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to create SASL server context: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + goto cleanup; + } + + return sasl; + +cleanup: + virNetSASLSessionFree(sasl); + return NULL; +} +
+void virNetSASLSessionRef(virNetSASLSessionPtr sasl) +{ + sasl->refs++; +} + +int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl, + int ssf) +{ + int err; + + err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot set external SSF %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + return 0; +} + +const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl) +{ + const void *val; + int err; + + err = sasl_getprop(sasl->conn, SASL_USERNAME, &val); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("cannot query SASL username on connection %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return NULL; + } + if (val == NULL) { + virNetError(VIR_ERR_AUTH_FAILED, + _("no client username was found")); + return NULL; + } + VIR_DEBUG("SASL client username %s", (const char *)val); + + return (const char*)val; +} + + +int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl) +{ + int err; + int ssf; + const void *val; + err = sasl_getprop(sasl->conn, SASL_SSF, &val); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("cannot query SASL ssf on connection %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + ssf = *(const int *)val; + return ssf; +} + +int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, + int minSSF, + int maxSSF, + bool allowAnonymous) +{ + sasl_security_properties_t secprops; + int err; + + VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu", + minSSF, maxSSF, allowAnonymous, sasl->maxbufsize); + + memset(&secprops, 0, sizeof secprops); + + secprops.min_ssf = minSSF; + secprops.max_ssf = maxSSF; + secprops.maxbufsize = sasl->maxbufsize; + secprops.security_flags = allowAnonymous ? 
0 : + SASL_SEC_NOANONYMOUS | SASL_SEC_NOPLAINTEXT; + + err = sasl_setprop(sasl->conn, SASL_SEC_PROPS, &secprops); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot set security props %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + + return 0; +} + + +static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl) +{ + unsigned *maxbufsize; + int err; + + err = sasl_getprop(sasl->conn, SASL_MAXOUTBUF, (const void **)&maxbufsize); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot get security props %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + + VIR_DEBUG("Negotiated bufsize is %u vs requested size %zu", + *maxbufsize, sasl->maxbufsize); + sasl->maxbufsize = *maxbufsize; + return 0; +} + +char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl) +{ + const char *mechlist; + char *ret; + int err; + + err = sasl_listmech(sasl->conn, + NULL, /* Don't need to set user */ + "", /* Prefix */ + ",", /* Separator */ + "", /* Suffix */ + &mechlist, + NULL, + NULL); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot list SASL mechanisms %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return NULL; + } + if (!(ret = strdup(mechlist))) { + virReportOOMError(); + return NULL; + } + return ret; +} + + +int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl, + const char *mechlist, + sasl_interact_t **prompt_need, + const char **clientout, + size_t *clientoutlen, + const char **mech) +{ + unsigned outlen = 0; + + VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p", + sasl, mechlist, prompt_need, clientout, clientoutlen, mech); + + int err = sasl_client_start(sasl->conn, + mechlist, + prompt_need, + clientout, + &outlen, + mech); + + *clientoutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return 
VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to start SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + + +int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl, + const char *serverin, + size_t serverinlen, + sasl_interact_t **prompt_need, + const char **clientout, + size_t *clientoutlen) +{ + unsigned inlen = serverinlen; + unsigned outlen = 0; + + VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p", + sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen); + + int err = sasl_client_step(sasl->conn, + serverin, + inlen, + prompt_need, + clientout, + &outlen); + *clientoutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to step SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + +int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, + const char *mechname, + const char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen) +{ + unsigned inlen = clientinlen; + unsigned outlen = 0; + int err = sasl_server_start(sasl->conn, + mechname, + clientin, + inlen, + serverout, + &outlen); + + *serveroutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to start SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + + +int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl, + const 
char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen) +{ + unsigned inlen = clientinlen; + unsigned outlen = 0; + + int err = sasl_server_step(sasl->conn, + clientin, + inlen, + serverout, + &outlen); + + *serveroutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to start SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + +size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl) +{ + return sasl->maxbufsize; +} + +ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen) +{ + unsigned inlen = inputLen; + unsigned outlen = 0; + int err; + + if (inputLen > sasl->maxbufsize) { + virReportSystemError(EINVAL, + _("SASL data length %zu too long, max %zu"), + inputLen, sasl->maxbufsize); + return -1; + } + + err = sasl_encode(sasl->conn, + input, + inlen, + output, + &outlen); + *outputlen = outlen; + + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("failed to encode SASL data: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + return 0; +} + +ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen) +{ + unsigned inlen = inputLen; + unsigned outlen = 0; + int err; + + if (inputLen > sasl->maxbufsize) { + virReportSystemError(EINVAL, + _("SASL data length %zu too long, max %zu"), + inputLen, sasl->maxbufsize); + return -1; + } + + err = sasl_decode(sasl->conn, + input, + inlen, + output, + &outlen); + *outputlen = outlen; + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("failed to decode SASL data: %d (%s)"), + err, 
sasl_errstring(err, NULL, NULL)); + return -1; + } + return 0; +} + +void virNetSASLSessionFree(virNetSASLSessionPtr sasl) +{ + if (!sasl) + return; + + sasl->refs--; + if (sasl->refs > 0) + return; + + if (sasl->conn) + sasl_dispose(&sasl->conn); + + VIR_FREE(sasl); +} diff --git a/src/rpc/virnetsaslcontext.h b/src/rpc/virnetsaslcontext.h new file mode 100644 index 0000000..1ec6451 --- /dev/null +++ b/src/rpc/virnetsaslcontext.h @@ -0,0 +1,120 @@ +/* + * virnetsaslcontext.h: SASL encryption/auth handling + * + * Copyright (C) 2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef __VIR_NET_CLIENT_SASL_CONTEXT_H__ +# define __VIR_NET_CLIENT_SASL_CONTEXT_H__ + +# include <sasl/sasl.h> + +# include <stdbool.h> +# include <sys/types.h> + +typedef struct _virNetSASLContext virNetSASLContext; +typedef virNetSASLContext *virNetSASLContextPtr; + +typedef struct _virNetSASLSession virNetSASLSession; +typedef virNetSASLSession *virNetSASLSessionPtr; + +enum { + VIR_NET_SASL_COMPLETE, + VIR_NET_SASL_CONTINUE, + VIR_NET_SASL_INTERACT, +}; + +virNetSASLContextPtr virNetSASLContextNewClient(void); +virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitelist); + +int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt, + const char *identity); + +void virNetSASLContextRef(virNetSASLContextPtr sasl); +void virNetSASLContextFree(virNetSASLContextPtr sasl); + +virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt, + const char *service, + const char *hostname, + const char *localAddr, + const char *remoteAddr, + const sasl_callback_t *cbs); +virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt, + const char *service, + const char *localAddr, + const char *remoteAddr); + +char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl); + +void virNetSASLSessionRef(virNetSASLSessionPtr sasl); + +int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl, + int ssf); + +int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl); + +const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl); + +int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, + int minSSF, + int maxSSF, + bool allowAnonymous); + +int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl, + const char *mechlist, + sasl_interact_t **prompt_need, + const char **clientout, + 
size_t *clientoutlen, + const char **mech); + +int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl, + const char *serverin, + size_t serverinlen, + sasl_interact_t **prompt_need, + const char **clientout, + size_t *clientoutlen); + +int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, + const char *mechname, + const char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen); + +int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl, + const char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen); + +size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl); + +ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen); + +ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen); + +void virNetSASLSessionFree(virNetSASLSessionPtr sasl); + +#endif /* __VIR_NET_CLIENT_SASL_CONTEXT_H__ */ -- 1.7.4

This extends the basic virNetSocket APIs to allow them to have a handle to the TLS/SASL session objects, once established. This ensures that any data reads/writes are automagically passed through the TLS/SASL encryption layers if required. * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up SASL/TLS encryption --- src/rpc/virnetsocket.c | 276 +++++++++++++++++++++++++++++++++++++++++++++++- src/rpc/virnetsocket.h | 11 ++ 2 files changed, 284 insertions(+), 3 deletions(-) diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index a0eb431..a5ee861 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -27,6 +27,9 @@ #include <sys/socket.h> #include <unistd.h> #include <sys/wait.h> +#ifdef HAVE_NETINET_TCP_H +# include <netinet/tcp.h> +#endif #include "virnetsocket.h" #include "util.h" @@ -55,6 +58,19 @@ struct _virNetSocket { virSocketAddr remoteAddr; char *localAddrStr; char *remoteAddrStr; + + virNetTLSSessionPtr tlsSession; +#if HAVE_SASL + virNetSASLSessionPtr saslSession; + + const char *saslDecoded; + size_t saslDecodedLength; + size_t saslDecodedOffset; + + const char *saslEncoded; + size_t saslEncodedLength; + size_t saslEncodedOffset; +#endif }; @@ -394,7 +410,7 @@ error: } -#if HAVE_SYS_UN_H +#ifdef HAVE_SYS_UN_H int virNetSocketNewConnectUNIX(const char *path, bool spawnDaemon, const char *binary, @@ -610,6 +626,14 @@ void virNetSocketFree(virNetSocketPtr sock) unlink(sock->localAddr.data.un.sun_path); #endif + /* Make sure it can't send any more I/O during shutdown */ + if (sock->tlsSession) + virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL); + virNetTLSSessionFree(sock->tlsSession); +#if HAVE_SASL + virNetSASLSessionFree(sock->saslSession); +#endif + VIR_FORCE_CLOSE(sock->fd); VIR_FORCE_CLOSE(sock->errfd); @@ -695,14 +719,260 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) return sock->remoteAddrStr; } -ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) + +static ssize_t 
virNetSocketTLSSessionWrite(const char *buf, + size_t len, + void *opaque) { + virNetSocketPtr sock = opaque; + return write(sock->fd, buf, len); +} + + +static ssize_t virNetSocketTLSSessionRead(char *buf, + size_t len, + void *opaque) +{ + virNetSocketPtr sock = opaque; return read(sock->fd, buf, len); } + +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess) +{ + if (sock->tlsSession) + virNetTLSSessionFree(sock->tlsSession); + sock->tlsSession = sess; + virNetTLSSessionSetIOCallbacks(sess, + virNetSocketTLSSessionWrite, + virNetSocketTLSSessionRead, + sock); + virNetTLSSessionRef(sess); +} + + +#if HAVE_SASL +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess) +{ + if (sock->saslSession) + virNetSASLSessionFree(sock->saslSession); + sock->saslSession = sess; + virNetSASLSessionRef(sess); +} +#endif + + +bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED) +{ +#if HAVE_SASL + if (sock->saslDecoded) + return true; +#endif + return false; +} + + +static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len) +{ + char *errout = NULL; + ssize_t ret; +reread: + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionRead(sock->tlsSession, buf, len); + } else { + ret = read(sock->fd, buf, len); + } + + if ((ret < 0) && (errno == EINTR)) + goto reread; + if ((ret < 0) && (errno == EAGAIN)) + return 0; + + if (ret <= 0 && + sock->errfd != -1 && + virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 && + errout != NULL) { + size_t elen = strlen(errout); + if (elen && errout[elen-1] == '\n') + errout[elen-1] = '\0'; + } + + if (ret < 0) { + if (errout) + virReportSystemError(errno, + _("Cannot recv data: %s"), errout); + else + virReportSystemError(errno, "%s", + _("Cannot recv data")); + ret = -1; + } else if (ret == 0) { + if (errout) + virReportSystemError(EIO, + _("End of file while reading 
data: %s"), errout); + else + virReportSystemError(EIO, "%s", + _("End of file while reading data")); + ret = -1; + } + + VIR_FREE(errout); + return ret; +} + +static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len) +{ + ssize_t ret; +rewrite: + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionWrite(sock->tlsSession, buf, len); + } else { + ret = write(sock->fd, buf, len); + } + + if (ret < 0) { + if (errno == EINTR) + goto rewrite; + if (errno == EAGAIN) + return 0; + + virReportSystemError(errno, "%s", + _("Cannot write data")); + return -1; + } + if (ret == 0) { + virReportSystemError(EIO, "%s", + _("End of file while writing data")); + return -1; + } + + return ret; +} + + +#if HAVE_SASL +static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len) +{ + ssize_t got; + + /* Need to read some more data off the wire */ + if (sock->saslDecoded == NULL) { + ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession); + char *encoded; + if (VIR_ALLOC_N(encoded, encodedLen) < 0) { + virReportOOMError(); + return -1; + } + encodedLen = virNetSocketReadWire(sock, encoded, encodedLen); + + if (encodedLen <= 0) { + VIR_FREE(encoded); + return encodedLen; + } + + if (virNetSASLSessionDecode(sock->saslSession, + encoded, encodedLen, + &sock->saslDecoded, &sock->saslDecodedLength) < 0) { + VIR_FREE(encoded); + return -1; + } + VIR_FREE(encoded); + + sock->saslDecodedOffset = 0; + } + + /* Some buffered decoded data to return now */ + got = sock->saslDecodedLength - sock->saslDecodedOffset; + + if (len > got) + len = got; + + memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len); + sock->saslDecodedOffset += len; + + if (sock->saslDecodedOffset == sock->saslDecodedLength) { + sock->saslDecoded = NULL; + sock->saslDecodedOffset = sock->saslDecodedLength = 0; + } + + return len; +} + + +static ssize_t 
virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len) +{ + int ret; + size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession); + + /* SASL doesn't neccessarily let us send the whole + buffer at once */ + if (tosend > len) + tosend = len; + + /* Not got any pending encoded data, so we need to encode raw stuff */ + if (sock->saslEncoded == NULL) { + if (virNetSASLSessionEncode(sock->saslSession, + buf, tosend, + &sock->saslEncoded, + &sock->saslEncodedLength) < 0) + return -1; + + sock->saslEncodedOffset = 0; + } + + /* Send some of the encoded stuff out on the wire */ + ret = virNetSocketWriteWire(sock, + sock->saslEncoded + sock->saslEncodedOffset, + sock->saslEncodedLength - sock->saslEncodedOffset); + + if (ret <= 0) + return ret; /* -1 error, 0 == egain */ + + /* Note how much we sent */ + sock->saslEncodedOffset += ret; + + /* Sent all encoded, so update raw buffer to indicate completion */ + if (sock->saslEncodedOffset == sock->saslEncodedLength) { + sock->saslEncoded = NULL; + sock->saslEncodedOffset = sock->saslEncodedLength = 0; + + /* Mark as complete, so caller detects completion */ + return tosend; + } else { + /* Still have stuff pending in saslEncoded buffer. + * Pretend to caller that we didn't send any yet. + * The caller will then retry with same buffer + * shortly, which lets us finish saslEncoded. 
+ */ + return 0; + } +} +#endif + + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) +{ +#if HAVE_SASL + if (sock->saslSession) + return virNetSocketReadSASL(sock, buf, len); + else +#endif + return virNetSocketReadWire(sock, buf, len); +} + ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) { - return write(sock->fd, buf, len); +#if HAVE_SASL + if (sock->saslSession) + return virNetSocketWriteSASL(sock, buf, len); + else +#endif + return virNetSocketWriteWire(sock, buf, len); } diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index c33b2e1..1be423b 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ -26,6 +26,10 @@ # include "network.h" # include "command.h" +# include "virnettlscontext.h" +# ifdef HAVE_SASL +# include "virnetsaslcontext.h" +# endif typedef struct _virNetSocket virNetSocket; typedef virNetSocket *virNetSocketPtr; @@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock, ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess); +# ifdef HAVE_SASL +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess); +# endif +bool virNetSocketHasCachedData(virNetSocketPtr sock); void virNetSocketFree(virNetSocketPtr sock); const char *virNetSocketLocalAddrString(virNetSocketPtr sock); -- 1.7.4

To facilitate the creation of new daemons providing XDR RPC services, pull a lot of the libvirtd daemon code into a set of reusable objects. * virNetServer: A server contains one or more services which accept incoming clients. It maintains the list of active clients. It has a list of RPC programs which can be used by clients. When a client produces a complete RPC message, the server passes it on to the corresponding program for handling, and queues any response back with the client. * virNetServerClient: Encapsulates a single client connection. It handles all I/O for the client, reading & writing RPC messages. * virNetServerProgram: Handles processing and dispatch of RPC method calls for a single RPC (program, version). Multiple programs can be registered with the server. * virNetServerService: Encapsulates the socket(s) listening for new connections. Each service listens on a single host/port, but may have multiple sockets on a dual-stack IPv4/IPv6 host. Each new daemon now merely has to define the list of RPC procedures & their handlers. It does not need to deal with any network-related functionality at all.
--- po/POTFILES.in | 3 + src/Makefile.am | 17 +- src/rpc/virnetserver.c | 716 +++++++++++++++++++++++++++++++ src/rpc/virnetserver.h | 80 ++++ src/rpc/virnetserverclient.c | 938 +++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetserverclient.h | 106 +++++ src/rpc/virnetserverprogram.c | 456 ++++++++++++++++++++ src/rpc/virnetserverprogram.h | 107 +++++ src/rpc/virnetserverservice.c | 247 +++++++++++ src/rpc/virnetserverservice.h | 65 +++ 10 files changed, 2734 insertions(+), 1 deletions(-) create mode 100644 src/rpc/virnetserver.c create mode 100644 src/rpc/virnetserver.h create mode 100644 src/rpc/virnetserverclient.c create mode 100644 src/rpc/virnetserverclient.h create mode 100644 src/rpc/virnetserverprogram.c create mode 100644 src/rpc/virnetserverprogram.h create mode 100644 src/rpc/virnetserverservice.c create mode 100644 src/rpc/virnetserverservice.h diff --git a/po/POTFILES.in b/po/POTFILES.in index 53d63a8..c071874 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -67,6 +67,9 @@ src/remote/remote_driver.c src/rpc/virnetmessage.c src/rpc/virnetsaslcontext.c src/rpc/virnetsocket.c +src/rpc/virnetserver.c +src/rpc/virnetserverclient.c +src/rpc/virnetserverprogram.c src/rpc/virnettlscontext.c src/secret/secret_driver.c src/security/security_apparmor.c diff --git a/src/Makefile.am b/src/Makefile.am index 5d20d63..b0a96b8 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1199,7 +1199,7 @@ libvirt_qemu_la_LIBADD = libvirt.la $(CYGWIN_EXTRA_LIBADD) EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) -noinst_LTLIBRARIES += libvirt-net-rpc.la +noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ @@ -1226,6 +1226,21 @@ libvirt_net_rpc_la_LDFLAGS = \ libvirt_net_rpc_la_LIBADD = \ $(CYGWIN_EXTRA_LIBADD) +libvirt_net_rpc_server_la_SOURCES = \ + rpc/virnetserverprogram.h rpc/virnetserverprogram.c \ + rpc/virnetserverservice.h rpc/virnetserverservice.c \ + rpc/virnetserverclient.h 
rpc/virnetserverclient.c \ + rpc/virnetserver.h rpc/virnetserver.c +libvirt_net_rpc_server_la_CFLAGS = \ + $(AM_CFLAGS) +libvirt_net_rpc_server_la_LDFLAGS = \ + $(AM_LDFLAGS) \ + $(CYGWIN_EXTRA_LDFLAGS) \ + $(MINGW_EXTRA_LDFLAGS)l +libvirt_net_rpc_server_la_LIBADD = \ + $(CYGWIN_EXTRA_LIBADD) + + libexec_PROGRAMS = libexec_PROGRAMS += libvirt_iohelper diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c new file mode 100644 index 0000000..5dbd7a9 --- /dev/null +++ b/src/rpc/virnetserver.c @@ -0,0 +1,716 @@ +/* + * virnetserver.c: generic network RPC server + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <unistd.h> +#include <string.h> + +#include "virnetserver.h" +#include "logging.h" +#include "memory.h" +#include "virterror_internal.h" +#include "threads.h" +#include "threadpool.h" +#include "util.h" +#include "files.h" +#include "event.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) 
\ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +typedef struct _virNetServerSignal virNetServerSignal; +typedef virNetServerSignal *virNetServerSignalPtr; + +struct _virNetServerSignal { + struct sigaction oldaction; + int signum; + virNetServerSignalFunc func; + void *opaque; +}; + +typedef struct _virNetServerJob virNetServerJob; +typedef virNetServerJob *virNetServerJobPtr; + +struct _virNetServerJob { + virNetServerClientPtr client; + virNetMessagePtr msg; +}; + +struct _virNetServer { + int refs; + + virMutex lock; + + virThreadPoolPtr workers; + + bool privileged; + + size_t nsignals; + virNetServerSignalPtr *signals; + int sigread; + int sigwrite; + int sigwatch; + + size_t nservices; + virNetServerServicePtr *services; + + size_t nprograms; + virNetServerProgramPtr *programs; + + size_t nclients; + size_t nclients_max; + virNetServerClientPtr *clients; + + unsigned int quit :1; + + virNetTLSContextPtr tls; + + unsigned int autoShutdownTimeout; + virNetServerAutoShutdownFunc autoShutdownFunc; + void *autoShutdownOpaque; + + virNetServerClientInitHook clientInitHook; +}; + + +static void virNetServerLock(virNetServerPtr srv) +{ + virMutexLock(&srv->lock); +} + +static void virNetServerUnlock(virNetServerPtr srv) +{ + virMutexUnlock(&srv->lock); +} + + +static void virNetServerHandleJob(void *jobOpaque, void *opaque) +{ + virNetServerPtr srv = opaque; + virNetServerJobPtr job = jobOpaque; + virNetServerProgramPtr prog = NULL; + size_t i; + + virNetServerClientRef(job->client); + + virNetServerLock(srv); + VIR_DEBUG("server=%p client=%p message=%p", + srv, job->client, job->msg); + + for (i = 0 ; i < srv->nprograms ; i++) { + if (virNetServerProgramMatches(srv->programs[i], job->msg)) { + prog = srv->programs[i]; + break; + } + } + + if (!prog) { + VIR_DEBUG("Cannot find program %d version %d", + job->msg->header.prog, + job->msg->header.vers); + goto error; + } + + virNetServerProgramRef(prog); + 
virNetServerUnlock(srv); + + if (virNetServerProgramDispatch(prog, + srv, + job->client, + job->msg) < 0) + goto error; + + virNetServerLock(srv); + virNetServerProgramFree(prog); + virNetServerUnlock(srv); + virNetServerClientFree(job->client); + + VIR_FREE(job); + return; + +error: + virNetServerUnlock(srv); + if (prog) + virNetServerProgramFree(prog); + virNetMessageFree(job->msg); + virNetServerClientClose(job->client); + virNetServerClientFree(job->client); + VIR_FREE(job); +} + + +static int virNetServerDispatchNewMessage(virNetServerClientPtr client, + virNetMessagePtr msg, + void *opaque) +{ + virNetServerPtr srv = opaque; + virNetServerJobPtr job; + int ret; + + VIR_DEBUG("server=%p client=%p message=%p", + srv, client, msg); + + if (VIR_ALLOC(job) < 0) { + virReportOOMError(); + return -1; + } + + job->client = client; + job->msg = msg; + + virNetServerLock(srv); + if ((ret = virThreadPoolSendJob(srv->workers, job)) < 0) + VIR_FREE(job); + virNetServerUnlock(srv); + + return ret; +} + + +static int virNetServerDispatchNewClient(virNetServerServicePtr svc ATTRIBUTE_UNUSED, + virNetServerClientPtr client, + void *opaque) +{ + virNetServerPtr srv = opaque; + + virNetServerLock(srv); + + if (srv->nclients >= srv->nclients_max) { + virNetError(VIR_ERR_RPC, + _("Too many active clients (%zu), dropping connection from %s"), + srv->nclients_max, virNetServerClientRemoteAddrString(client)); + goto error; + } + + if (virNetServerClientInit(client) < 0) + goto error; + + if (srv->clientInitHook && + srv->clientInitHook(srv, client) < 0) + goto error; + + if (VIR_EXPAND_N(srv->clients, srv->nclients, 1) < 0) { + virReportOOMError(); + goto error; + } + srv->clients[srv->nclients-1] = client; + virNetServerClientRef(client); + + virNetServerClientSetDispatcher(client, + virNetServerDispatchNewMessage, + srv); + + virNetServerUnlock(srv); + return 0; + +error: + virNetServerUnlock(srv); + return -1; +} + + +static void virNetServerFatalSignal(int sig, siginfo_t * 
siginfo ATTRIBUTE_UNUSED, + void* context ATTRIBUTE_UNUSED) +{ + struct sigaction sig_action; + int origerrno; + + origerrno = errno; + virLogEmergencyDumpAll(sig); + + /* + * If the signal is fatal, avoid looping over this handler + * by desactivating it + */ +#ifdef SIGUSR2 + if (sig != SIGUSR2) { +#endif + sig_action.sa_handler = SIG_IGN; + sigaction(sig, &sig_action, NULL); +#ifdef SIGUSR2 + } +#endif + errno = origerrno; +} + + +virNetServerPtr virNetServerNew(size_t min_workers, + size_t max_workers, + size_t max_clients, + virNetServerClientInitHook clientInitHook) +{ + virNetServerPtr srv; + struct sigaction sig_action; + + if (VIR_ALLOC(srv) < 0) { + virReportOOMError(); + return NULL; + } + + srv->refs = 1; + + if (!(srv->workers = virThreadPoolNew(min_workers, max_workers, + virNetServerHandleJob, + srv))) + goto error; + + srv->nclients_max = max_clients; + srv->sigwrite = srv->sigread = -1; + srv->clientInitHook = clientInitHook; + srv->privileged = geteuid() == 0 ? true : false; + + if (virMutexInit(&srv->lock) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("cannot initialize mutex")); + goto error; + } + + if (virEventRegisterDefaultImpl() < 0) + goto error; + + memset(&sig_action, 0, sizeof(sig_action)); + sig_action.sa_handler = SIG_IGN; + sigaction(SIGPIPE, &sig_action, NULL); + + /* + * catch fatal errors to dump a log, also hook to USR2 for dynamic + * debugging purposes or testing + */ + sig_action.sa_sigaction = virNetServerFatalSignal; + sigaction(SIGFPE, &sig_action, NULL); + sigaction(SIGSEGV, &sig_action, NULL); + sigaction(SIGILL, &sig_action, NULL); + sigaction(SIGABRT, &sig_action, NULL); +#ifdef SIGBUS + sigaction(SIGBUS, &sig_action, NULL); +#endif +#ifdef SIGUSR2 + sigaction(SIGUSR2, &sig_action, NULL); +#endif + + VIR_DEBUG("srv=%p refs=%d", srv, srv->refs); + return srv; + +error: + virNetServerFree(srv); + return NULL; +} + + +void virNetServerRef(virNetServerPtr srv) +{ + virNetServerLock(srv); + srv->refs++; + 
VIR_DEBUG("srv=%p refs=%d", srv, srv->refs); + virNetServerUnlock(srv); +} + + +bool virNetServerIsPrivileged(virNetServerPtr srv) +{ + bool priv; + virNetServerLock(srv); + priv = srv->privileged; + virNetServerUnlock(srv); + return priv; +} + + +void virNetServerAutoShutdown(virNetServerPtr srv, + unsigned int timeout, + virNetServerAutoShutdownFunc func, + void *opaque) +{ + virNetServerLock(srv); + + srv->autoShutdownTimeout = timeout; + srv->autoShutdownFunc = func; + srv->autoShutdownOpaque = opaque; + + virNetServerUnlock(srv); +} + + +static sig_atomic_t sigErrors = 0; +static int sigLastErrno = 0; +static int sigWrite = -1; + +static void virNetServerSignalHandler(int sig, siginfo_t * siginfo, + void* context ATTRIBUTE_UNUSED) +{ + int origerrno; + int r; + + /* set the sig num in the struct */ + siginfo->si_signo = sig; + + origerrno = errno; + r = safewrite(sigWrite, siginfo, sizeof(*siginfo)); + if (r == -1) { + sigErrors++; + sigLastErrno = errno; + } + errno = origerrno; +} + +static void +virNetServerSignalEvent(int watch, + int fd ATTRIBUTE_UNUSED, + int events ATTRIBUTE_UNUSED, + void *opaque) { + virNetServerPtr srv = opaque; + siginfo_t siginfo; + int i; + + virNetServerLock(srv); + + if (saferead(srv->sigread, &siginfo, sizeof(siginfo)) != sizeof(siginfo)) { + virReportSystemError(errno, "%s", + _("Failed to read from signal pipe")); + virEventRemoveHandle(watch); + srv->sigwatch = -1; + goto cleanup; + } + + for (i = 0 ; i < srv->nsignals ; i++) { + if (siginfo.si_signo == srv->signals[i]->signum) { + virNetServerSignalFunc func = srv->signals[i]->func; + void *funcopaque = srv->signals[i]->opaque; + virNetServerUnlock(srv); + func(srv, &siginfo, funcopaque); + return; + } + } + + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected signal received: %d"), siginfo.si_signo); + +cleanup: + virNetServerUnlock(srv); +} + +static int virNetServerSignalSetup(virNetServerPtr srv) +{ + int fds[2]; + + if (srv->sigwrite != -1) + return 0; + + if 
(pipe(fds) < 0) { + virReportSystemError(errno, "%s", + _("Unable to create signal pipe")); + return -1; + } + + if (virSetNonBlock(fds[0]) < 0 || + virSetNonBlock(fds[1]) < 0 || + virSetCloseExec(fds[0]) < 0 || + virSetCloseExec(fds[1]) < 0) { + virReportSystemError(errno, "%s", + _("Failed to setup pipe flags")); + goto error; + } + + if ((srv->sigwatch = virEventAddHandle(fds[0], + VIR_EVENT_HANDLE_READABLE, + virNetServerSignalEvent, + srv, NULL)) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("Failed to add signal handle watch")); + goto error; + } + + srv->sigread = fds[0]; + srv->sigwrite = fds[1]; + sigWrite = fds[1]; + + return 0; + +error: + VIR_FORCE_CLOSE(fds[0]); + VIR_FORCE_CLOSE(fds[1]); + return -1; +} + +int virNetServerAddSignalHandler(virNetServerPtr srv, + int signum, + virNetServerSignalFunc func, + void *opaque) +{ + virNetServerSignalPtr sigdata; + struct sigaction sig_action; + + virNetServerLock(srv); + + if (virNetServerSignalSetup(srv) < 0) + goto error; + + if (VIR_EXPAND_N(srv->signals, srv->nsignals, 1) < 0) + goto no_memory; + + if (VIR_ALLOC(sigdata) < 0) + goto no_memory; + + sigdata->signum = signum; + sigdata->func = func; + sigdata->opaque = opaque; + + memset(&sig_action, 0, sizeof(sig_action)); + sig_action.sa_sigaction = virNetServerSignalHandler; +#ifdef SA_SIGINFO + sig_action.sa_flags = SA_SIGINFO; +#endif + sigemptyset(&sig_action.sa_mask); + + sigaction(signum, &sig_action, &sigdata->oldaction); + + srv->signals[srv->nsignals-1] = sigdata; + + virNetServerUnlock(srv); + return 0; + +no_memory: + virReportOOMError(); +error: + VIR_FREE(sigdata); + virNetServerUnlock(srv); + return -1; +} + + + +int virNetServerAddService(virNetServerPtr srv, + virNetServerServicePtr svc) +{ + virNetServerLock(srv); + + if (VIR_EXPAND_N(srv->services, srv->nservices, 1) < 0) + goto no_memory; + + srv->services[srv->nservices-1] = svc; + virNetServerServiceRef(svc); + + virNetServerServiceSetDispatcher(svc, + 
virNetServerDispatchNewClient, + srv); + + virNetServerUnlock(srv); + return 0; + +no_memory: + virReportOOMError(); + virNetServerUnlock(srv); + return -1; +} + +int virNetServerAddProgram(virNetServerPtr srv, + virNetServerProgramPtr prog) +{ + virNetServerLock(srv); + + if (VIR_EXPAND_N(srv->programs, srv->nprograms, 1) < 0) + goto no_memory; + + srv->programs[srv->nprograms-1] = prog; + virNetServerProgramRef(prog); + + virNetServerUnlock(srv); + return 0; + +no_memory: + virReportOOMError(); + virNetServerUnlock(srv); + return -1; +} + +int virNetServerSetTLSContext(virNetServerPtr srv, + virNetTLSContextPtr tls) +{ + srv->tls = tls; + virNetTLSContextRef(tls); + return 0; +} + + +static void virNetServerAutoShutdownTimer(int timerid ATTRIBUTE_UNUSED, + void *opaque) { + virNetServerPtr srv = opaque; + + virNetServerLock(srv); + + if (srv->autoShutdownFunc(srv, srv->autoShutdownOpaque)) { + VIR_DEBUG0("Automatic shutdown triggered"); + srv->quit = 1; + } + + virNetServerUnlock(srv); +} + + +void virNetServerUpdateServices(virNetServerPtr srv, + bool enabled) +{ + int i; + + virNetServerLock(srv); + for (i = 0 ; i < srv->nservices ; i++) + virNetServerServiceToggle(srv->services[i], enabled); + + virNetServerUnlock(srv); +} + + +void virNetServerRun(virNetServerPtr srv) +{ + int timerid = -1; + int timerActive = 0; + int i; + + virNetServerLock(srv); + + if (srv->autoShutdownTimeout && + (timerid = virEventAddTimeout(-1, + virNetServerAutoShutdownTimer, + srv, NULL)) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("Failed to register shutdown timeout")); + goto cleanup; + } + + while (!srv->quit) { + /* A shutdown timeout is specified, so check + * if any drivers have active state, if not + * shutdown after timeout seconds + */ + if (srv->autoShutdownTimeout) { + if (timerActive) { + if (srv->clients) { + VIR_DEBUG("Deactivating shutdown timer %d", timerid); + virEventUpdateTimeout(timerid, -1); + timerActive = 0; + } + } else { + if (!srv->clients) { + 
VIR_DEBUG("Activating shutdown timer %d", timerid); + virEventUpdateTimeout(timerid, + srv->autoShutdownTimeout * 1000); + timerActive = 1; + } + } + } + + virNetServerUnlock(srv); + if (virEventRunDefaultImpl() < 0) { + virNetServerLock(srv); + VIR_DEBUG0("Loop iteration error, exiting"); + break; + } + virNetServerLock(srv); + + reprocess: + for (i = 0 ; i < srv->nclients ; i++) { + if (virNetServerClientWantClose(srv->clients[i])) + virNetServerClientClose(srv->clients[i]); + if (virNetServerClientIsClosed(srv->clients[i])) { + virNetServerClientFree(srv->clients[i]); + if (srv->nclients > 1) { + memmove(srv->clients + i, + srv->clients + i + 1, + sizeof(*srv->clients) * (srv->nclients - (i + 1))); + VIR_SHRINK_N(srv->clients, srv->nclients, 1); + } else { + VIR_FREE(srv->clients); + srv->nclients = 0; + } + + goto reprocess; + } + } + } + +cleanup: + virNetServerUnlock(srv); +} + + +void virNetServerQuit(virNetServerPtr srv) +{ + virNetServerLock(srv); + + srv->quit = 1; + + virNetServerUnlock(srv); +} + +void virNetServerFree(virNetServerPtr srv) +{ + int i; + + if (!srv) + return; + + virNetServerLock(srv); + VIR_DEBUG("srv=%p refs=%d", srv, srv->refs); + srv->refs--; + if (srv->refs > 0) { + virNetServerUnlock(srv); + return; + } + + for (i = 0 ; i < srv->nservices ; i++) + virNetServerServiceToggle(srv->services[i], false); + + virThreadPoolFree(srv->workers); + + for (i = 0 ; i < srv->nsignals ; i++) { + sigaction(srv->signals[i]->signum, &srv->signals[i]->oldaction, NULL); + VIR_FREE(srv->signals[i]); + } + VIR_FREE(srv->signals); + VIR_FORCE_CLOSE(srv->sigread); + VIR_FORCE_CLOSE(srv->sigwrite); + if (srv->sigwatch > 0) + virEventRemoveHandle(srv->sigwatch); + + for (i = 0 ; i < srv->nservices ; i++) + virNetServerServiceFree(srv->services[i]); + VIR_FREE(srv->services); + + for (i = 0 ; i < srv->nprograms ; i++) + virNetServerProgramFree(srv->programs[i]); + VIR_FREE(srv->programs); + + for (i = 0 ; i < srv->nclients ; i++) { + 
virNetServerClientClose(srv->clients[i]); + virNetServerClientFree(srv->clients[i]); + } + VIR_FREE(srv->clients); + + virNetServerUnlock(srv); + virMutexDestroy(&srv->lock); + VIR_FREE(srv); +} diff --git a/src/rpc/virnetserver.h b/src/rpc/virnetserver.h new file mode 100644 index 0000000..8b8b6a9 --- /dev/null +++ b/src/rpc/virnetserver.h @@ -0,0 +1,80 @@ +/* + * virnetserver.h: generic network RPC server + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_SERVER_H__
+# define __VIR_NET_SERVER_H__
+
+# include <stdbool.h>
+# include <signal.h>
+
+# include "virnettlscontext.h"
+# include "virnetserverprogram.h"
+# include "virnetserverclient.h"
+# include "virnetserverservice.h"
+
+/* Hook recorded by virNetServerNew() to be run against clients.
+ * NOTE(review): the call site and the meaning of the return value
+ * are not part of this file - confirm before relying on them. */
+typedef int (*virNetServerClientInitHook)(virNetServerPtr srv,
+                                          virNetServerClientPtr client);
+
+/*
+ * virNetServerNew:
+ * @min_workers: minimum size of the worker thread pool
+ * @max_workers: maximum size of the worker thread pool
+ * @max_clients: client limit recorded in the server
+ * @clientInitHook: hook to run against new clients
+ *
+ * Create a server holding a single reference. As side effects this
+ * registers the default event loop implementation, ignores SIGPIPE and
+ * installs a handler for fatal signals (and SIGUSR2 where available)
+ * that dumps the log buffer.
+ *
+ * Returns the new server, or NULL on error.
+ */
+virNetServerPtr virNetServerNew(size_t min_workers,
+                                size_t max_workers,
+                                size_t max_clients,
+                                virNetServerClientInitHook clientInitHook);
+
+/* Called from the auto-shutdown timer with the server locked; a
+ * non-zero return makes virNetServerRun() stop looping. */
+typedef int (*virNetServerAutoShutdownFunc)(virNetServerPtr srv, void *opaque);
+
+/* Take an extra reference on @srv; drop it with virNetServerFree() */
+void virNetServerRef(virNetServerPtr srv);
+
+/* Returns true if the effective UID was 0 when @srv was created */
+bool virNetServerIsPrivileged(virNetServerPtr srv);
+
+/*
+ * virNetServerAutoShutdown:
+ *
+ * Once @srv has had no clients for @timeout, call @func; if it returns
+ * non-zero, virNetServerRun() quits. A @timeout of 0 disables this.
+ * The timer is only created when virNetServerRun() starts, so this
+ * must be configured beforehand.
+ * NOTE(review): @timeout looks like seconds (Run multiplies it by 1000
+ * before arming the event timer) - confirm against the event API.
+ */
+void virNetServerAutoShutdown(virNetServerPtr srv,
+                              unsigned int timeout,
+                              virNetServerAutoShutdownFunc func,
+                              void *opaque);
+
+/* Signal callback: runs from the event loop (not in signal context)
+ * with the server unlocked; @info is the siginfo_t captured by the
+ * signal handler. */
+typedef void (*virNetServerSignalFunc)(virNetServerPtr srv, siginfo_t *info, void *opaque);
+
+/*
+ * virNetServerAddSignalHandler:
+ *
+ * Route delivery of @signum to @func. The real signal handler only
+ * writes the siginfo_t to an internal pipe; @func is invoked later
+ * from the event loop. The previous disposition of @signum is
+ * restored by virNetServerFree().
+ *
+ * Returns 0 on success, -1 on error.
+ */
+int virNetServerAddSignalHandler(virNetServerPtr srv,
+                                 int signum,
+                                 virNetServerSignalFunc func,
+                                 void *opaque);
+
+/* Add @svc (taking a reference) and route its new connections to
+ * @srv. Returns 0 on success, -1 on OOM. */
+int virNetServerAddService(virNetServerPtr srv,
+                           virNetServerServicePtr svc);
+
+/* Register @prog (taking a reference). Returns 0, or -1 on OOM. */
+int virNetServerAddProgram(virNetServerPtr srv,
+                           virNetServerProgramPtr prog);
+
+/* Set the TLS context, taking a reference on @tls. Always returns 0.
+ * NOTE(review): does not take the server lock and does not release a
+ * previously set context. */
+int virNetServerSetTLSContext(virNetServerPtr srv,
+                              virNetTLSContextPtr tls);
+
+/* Enable (@enabled true) or disable every registered service */
+void virNetServerUpdateServices(virNetServerPtr srv,
+                                bool enabled);
+
+/* Run the event loop until virNetServerQuit(), auto-shutdown or an
+ * event loop error. After each iteration, clients that asked to be
+ * closed are closed and closed clients are released. */
+void virNetServerRun(virNetServerPtr srv);
+
+/* Set the quit flag; virNetServerRun() checks it after each event
+ * loop iteration. */
+void virNetServerQuit(virNetServerPtr srv);
+
+/* Drop a reference. Releasing the last one disables services, frees
+ * the worker pool, restores signal handlers, and releases services,
+ * programs and clients. NULL is accepted. */
+void virNetServerFree(virNetServerPtr srv);
+
+
+#endif
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
new file mode 100644
index 0000000..2f42dba
--- /dev/null
+++ b/src/rpc/virnetserverclient.c
@@ -0,0 +1,938 @@
+/*
+ * virnetserverclient.c: generic network RPC server client
+ *
+ * Copyright (C) 2006-2010 Red Hat, Inc.
+ * Copyright (C) 2006 Daniel P. Berrange
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#if HAVE_SASL
+# include <sasl/sasl.h>
+#endif
+
+#include "virnetserverclient.h"
+
+#include "logging.h"
+#include "virterror_internal.h"
+#include "memory.h"
+#include "threads.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+/* Report an RPC error tagged with the caller's file/function/line */
+#define virNetError(code, ...)                                    \
+    virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__,      \
+                         __FUNCTION__, __LINE__, __VA_ARGS__)
+
+/* Allow for filtering of incoming messages to a custom
+ * dispatch processing queue, instead of the workers.
+ * This allows for certain types of messages to be handled
+ * strictly "in order"
+ */
+
+typedef struct _virNetServerClientFilter virNetServerClientFilter;
+typedef virNetServerClientFilter *virNetServerClientFilterPtr;
+
+struct _virNetServerClientFilter {
+    int id;
+    virNetServerClientFilterFunc func;
+    void *opaque;
+
+    virNetServerClientFilterPtr next;
+};
+
+
+struct _virNetServerClient
+{
+    int refs;
+    bool wantClose;
+    virMutex lock;
+    virNetSocketPtr sock;
+    int auth;
+    bool readonly;
+    char *identity;
+    virNetTLSContextPtr tlsCtxt;
+    virNetTLSSessionPtr tls;
+#if HAVE_SASL
+    virNetSASLSessionPtr sasl;
+#endif
+
+    /* Count of the receive buffer plus RPC calls in
+     * progress (dispatched but not yet answered via
+     * 'tx'). Does not count async events which are
+     * not used for throttling calculations */
+    size_t nrequests;
+    size_t nrequests_max;
+    /* Zero or one messages being received. Zero if
+     * nrequests >= nrequests_max and throttling */
+    virNetMessagePtr rx;
+    /* Zero or many messages waiting for transmit
+     * back to client, including async events */
+    virNetMessagePtr tx;
+
+    /* Filters to capture messages that would otherwise
+     * be passed to dispatchFunc (ie the workers). The
+     * most recently added filter is consulted first */
+    virNetServerClientFilterPtr filters;
+    int nextFilterID;
+
+    virNetServerClientDispatchFunc dispatchFunc;
+    void *dispatchOpaque;
+
+    void *privateData;
+    virNetServerClientFreeFunc privateDataFreeFunc;
+};
+
+
+static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque);
+static void virNetServerClientUpdateEvent(virNetServerClientPtr client);
+
+static void virNetServerClientLock(virNetServerClientPtr client)
+{
+    virMutexLock(&client->lock);
+}
+
+static void virNetServerClientUnlock(virNetServerClientPtr client)
+{
+    virMutexUnlock(&client->lock);
+}
+
+
+/*
+ * Work out which I/O events to watch for on the socket
+ *
+ * @client: a locked client object
+ *
+ * Returns 0 (watch nothing) once the socket is gone or the
+ * client is marked for closing; otherwise a mask of
+ * VIR_EVENT_HANDLE_READABLE/WRITABLE driven by the TLS handshake
+ * direction, or by whether rx/tx messages are pending.
+ */
+static int
+virNetServerClientCalculateHandleMode(virNetServerClientPtr client) {
+    int mode = 0;
+
+
+    VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p",
+              client->tls,
+              client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1,
+              client->rx,
+              client->tx);
+    if (!client->sock || client->wantClose)
+        return 0;
+
+    if (client->tls) {
+        switch (virNetTLSSessionGetHandshakeStatus(client->tls)) {
+        case VIR_NET_TLS_HANDSHAKE_RECVING:
+            mode |= VIR_EVENT_HANDLE_READABLE;
+            break;
+        case VIR_NET_TLS_HANDSHAKE_SENDING:
+            mode |= VIR_EVENT_HANDLE_WRITABLE;
+            break;
+        default:
+        case VIR_NET_TLS_HANDSHAKE_COMPLETE:
+            if (client->rx)
+                mode |= VIR_EVENT_HANDLE_READABLE;
+            if (client->tx)
+                mode |= VIR_EVENT_HANDLE_WRITABLE;
+        }
+    } else {
+        /* If there is a message on the rx queue then
+         * we're wanting more input */
+        if (client->rx)
+            mode |= VIR_EVENT_HANDLE_READABLE;
+
+        /* If there are one or more messages to send back to client,
+           then monitor for writability on socket */
+        if (client->tx)
+            mode |= VIR_EVENT_HANDLE_WRITABLE;
+    }
+    VIR_DEBUG("mode=%d", mode);
+    return mode;
+}
+
+/*
+ * Register the socket I/O callback for the first time
+ *
+ * @client: a locked client object
+ *
+ * Returns 0 on success, -1 on failure
+ */
+static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
+{
+    int mode = virNetServerClientCalculateHandleMode(client);
+
+    VIR_DEBUG("Registering client event callback %d", mode);
+    if (virNetSocketAddIOCallback(client->sock,
+                                  mode,
+                                  virNetServerClientDispatchEvent,
+                                  client) < 0)
+        return -1;
+
+    return 0;
+}
+
+/*
+ * Refresh the socket's watched events after a state change
+ *
+ * @client: a locked client object
+ */
+static void virNetServerClientUpdateEvent(virNetServerClientPtr client)
+{
+    int mode;
+
+    if (!client->sock)
+        return;
+
+    mode = virNetServerClientCalculateHandleMode(client);
+
+    virNetSocketUpdateIOCallback(client->sock, mode);
+}
+
+
+/*
+ * Register a filter consulted before normal dispatch of each
+ * incoming message. Filters are pushed on the head of the list,
+ * so the newest one runs first.
+ *
+ * Returns the new filter's ID, or -1 on OOM
+ */
+int virNetServerClientAddFilter(virNetServerClientPtr client,
+                                virNetServerClientFilterFunc func,
+                                void *opaque)
+{
+    virNetServerClientFilterPtr filter;
+    int ret = -1;
+
+    virNetServerClientLock(client);
+
+    if (VIR_ALLOC(filter) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
+
+    filter->id = client->nextFilterID++;
+    filter->func = func;
+    filter->opaque = opaque;
+
+    filter->next = client->filters;
+    client->filters = filter;
+
+    ret = filter->id;
+
+cleanup:
+    virNetServerClientUnlock(client);
+    return ret;
+}
+
+
+/*
+ * Unregister and free the filter identified by @filterID (as
+ * returned by virNetServerClientAddFilter). Unknown IDs are
+ * ignored; other filters are left in place.
+ */
+void virNetServerClientRemoveFilter(virNetServerClientPtr client,
+                                    int filterID)
+{
+    virNetServerClientFilterPtr tmp, prev;
+    virNetServerClientLock(client);
+
+    prev = NULL;
+    tmp = client->filters;
+    while (tmp) {
+        if (tmp->id == filterID) {
+            if (prev)
+                prev->next = tmp->next;
+            else
+                client->filters = tmp->next;
+
+            VIR_FREE(tmp);
+            break;
+        }
+        /* Track the predecessor so an interior node is unlinked
+         * without disturbing the list head */
+        prev = tmp;
+        tmp = tmp->next;
+    }
+
+    virNetServerClientUnlock(client);
+}
+
+
+/*
+ * Check the client's access once the TLS handshake is complete
+ *
+ * @client: a locked client object
+ *
+ * Verifies the peer certificate and, on success, queues a single
+ * '\1' confirmation byte on the (empty) tx queue.
+ *
+ * Returns 0 on success, -1 if the client must be rejected
+ */
+static int
+virNetServerClientCheckAccess(virNetServerClientPtr client)
+{
+    virNetMessagePtr confirm;
+
+    /* Verify client certificate. */
+    if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0)
+        return -1;
+
+    if (client->tx) {
+        VIR_INFO0(_("client had unexpected data pending tx after access check"));
+        return -1;
+    }
+
+    if (!(confirm = virNetMessageNew()))
+        return -1;
+
+    /* Checks have succeeded.  Write a '\1' byte back to the client to
+     * indicate this (otherwise the socket is abruptly closed).
+     * (NB. The '\1' byte is sent in an encrypted record).
+     */
+    confirm->bufferLength = 1;
+    confirm->bufferOffset = 0;
+    confirm->buffer[0] = '\1';
+
+    client->tx = confirm;
+
+    return 0;
+}
+
+
+/*
+ * Create a client for an accepted socket, holding one reference.
+ *
+ * On success the client owns @sock and holds a reference on @tls
+ * (if non-NULL). On failure NULL is returned and the caller still
+ * owns @sock.
+ */
+virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
+                                            int auth,
+                                            bool readonly,
+                                            virNetTLSContextPtr tls)
+{
+    virNetServerClientPtr client;
+
+    VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth, tls);
+
+    if (VIR_ALLOC(client) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    if (virMutexInit(&client->lock) < 0) {
+        /* The lock is unusable, so virNetServerClientFree() must not
+         * be used here; nothing else has been acquired yet */
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("cannot initialize mutex"));
+        VIR_FREE(client);
+        return NULL;
+    }
+
+    client->refs = 1;
+    client->sock = sock;
+    client->auth = auth;
+    client->readonly = readonly;
+    client->tlsCtxt = tls;
+    client->nrequests_max = 10; /* XXX */
+
+    if (tls)
+        virNetTLSContextRef(tls);
+
+    /* Prepare one for packet receive */
+    if (!(client->rx = virNetMessageNew()))
+        goto error;
+    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+    client->nrequests = 1;
+
+    VIR_DEBUG("client=%p refs=%d", client, client->refs);
+
+    return client;
+
+error:
+    /* XXX ref counting is better than this */
+    client->sock = NULL; /* Caller owns 'sock' upon failure */
+    virNetServerClientFree(client);
+    return NULL;
+}
+
+/* Take an extra reference; drop it with virNetServerClientFree() */
+void virNetServerClientRef(virNetServerClientPtr client)
+{
+    virNetServerClientLock(client);
+    client->refs++;
+    VIR_DEBUG("client=%p refs=%d", client, client->refs);
+    virNetServerClientUnlock(client);
+}
+
+
+/* Auth scheme passed to virNetServerClientNew() */
+int virNetServerClientGetAuth(virNetServerClientPtr client)
+{
+    int auth;
+    virNetServerClientLock(client);
+    auth = client->auth;
+    virNetServerClientUnlock(client);
+    return auth;
+}
+
+/* Readonly flag passed to virNetServerClientNew() */
+bool virNetServerClientGetReadonly(virNetServerClientPtr client)
+{
+    bool readonly;
+    virNetServerClientLock(client);
+    readonly = client->readonly;
+    virNetServerClientUnlock(client);
+    return readonly;
+}
+
+
+/* True while the client has a TLS session (set up by Init,
+ * dropped by Close) */
+bool virNetServerClientHasTLSSession(virNetServerClientPtr client)
+{
+    bool has;
+    virNetServerClientLock(client);
+    has = client->tls ? true : false;
+    virNetServerClientUnlock(client);
+    return has;
+}
+
+/* Key size of the TLS session, or 0 when there is none */
+int virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
+{
+    int size = 0;
+    virNetServerClientLock(client);
+    if (client->tls)
+        size = virNetTLSSessionGetKeySize(client->tls);
+    virNetServerClientUnlock(client);
+    return size;
+}
+
+int virNetServerClientGetFD(virNetServerClientPtr client)
+{
+    int fd = 0;
+    virNetServerClientLock(client);
+    fd = virNetSocketGetFD(client->sock);
+    virNetServerClientUnlock(client);
+    return fd;
+}
+
+int virNetServerClientGetLocalIdentity(virNetServerClientPtr client,
+                                       uid_t *uid, pid_t *pid)
+{
+    int ret;
+    virNetServerClientLock(client);
+    ret = virNetSocketGetLocalIdentity(client->sock, uid, pid);
+    virNetServerClientUnlock(client);
+    return ret;
+}
+
+/* True if the connection uses TLS, SASL, or a local socket */
+bool virNetServerClientIsSecure(virNetServerClientPtr client)
+{
+    bool secure = false;
+    virNetServerClientLock(client);
+    if (client->tls)
+        secure = true;
+#if HAVE_SASL
+    if (client->sasl)
+        secure = true;
+#endif
+    if (virNetSocketIsLocal(client->sock))
+        secure = true;
+    virNetServerClientUnlock(client);
+    return secure;
+}
+
+
+#if HAVE_SASL
+void virNetServerClientSetSASLSession(virNetServerClientPtr client,
+                                      virNetSASLSessionPtr sasl)
+{
+    /* We don't set the sasl session on the socket here
+     * because we need to send out the auth confirmation
+     * in the clear. Only once we complete the next 'tx'
+     * operation do we switch to SASL mode
+     */
+    virNetServerClientLock(client);
+    client->sasl = sasl;
+    virNetSASLSessionRef(sasl);
+    virNetServerClientUnlock(client);
+}
+#endif
+
+
+/*
+ * Record a copy of @identity, replacing any previous one.
+ *
+ * Returns 0 on success, -1 on OOM (previous identity kept)
+ */
+int virNetServerClientSetIdentity(virNetServerClientPtr client,
+                                  const char *identity)
+{
+    int ret = -1;
+    char *newIdentity;
+
+    virNetServerClientLock(client);
+    if (!(newIdentity = strdup(identity))) {
+        virReportOOMError();
+        goto error;
+    }
+    /* Replace rather than leak an identity set by an earlier call */
+    VIR_FREE(client->identity);
+    client->identity = newIdentity;
+    ret = 0;
+
+error:
+    virNetServerClientUnlock(client);
+    return ret;
+}
+
+/* Returns the identity string owned by the client (not a copy),
+ * or NULL if none has been set */
+const char *virNetServerClientGetIdentity(virNetServerClientPtr client)
+{
+    const char *identity;
+    virNetServerClientLock(client);
+    identity = client->identity;
+    virNetServerClientUnlock(client);
+    return identity;
+}
+
+/* Attach @opaque, first releasing any previous data through its
+ * free function; @ff is called on @opaque when the client is freed */
+void virNetServerClientSetPrivateData(virNetServerClientPtr client,
+                                      void *opaque,
+                                      virNetServerClientFreeFunc ff)
+{
+    virNetServerClientLock(client);
+
+    if (client->privateData &&
+        client->privateDataFreeFunc)
+        client->privateDataFreeFunc(client->privateData);
+
+    client->privateData = opaque;
+    client->privateDataFreeFunc = ff;
+
+    virNetServerClientUnlock(client);
+}
+
+
+void *virNetServerClientGetPrivateData(virNetServerClientPtr client)
+{
+    void *data;
+    virNetServerClientLock(client);
+    data = client->privateData;
+    virNetServerClientUnlock(client);
+    return data;
+}
+
+
+/* Set the callback that receives each complete incoming message
+ * not captured by a filter */
+void virNetServerClientSetDispatcher(virNetServerClientPtr client,
+                                     virNetServerClientDispatchFunc func,
+                                     void *opaque)
+{
+    virNetServerClientLock(client);
+    client->dispatchFunc = func;
+    client->dispatchOpaque = opaque;
+    virNetServerClientUnlock(client);
+}
+
+
+const char *virNetServerClientLocalAddrString(virNetServerClientPtr client)
+{
+    return virNetSocketLocalAddrString(client->sock);
+}
+
+
+const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client)
+{
+    return virNetSocketRemoteAddrString(client->sock);
+}
+
+
+/*
+ * Free every message still queued for receive or transmit
+ *
+ * @client: a locked client object
+ */
+static void virNetServerClientDiscardMessages(virNetServerClientPtr client)
+{
+    while (client->rx) {
+        virNetMessagePtr msg
+            = virNetMessageQueueServe(&client->rx);
+        virNetMessageFree(msg);
+    }
+    while (client->tx) {
+        virNetMessagePtr msg
+            = virNetMessageQueueServe(&client->tx);
+        virNetMessageFree(msg);
+    }
+}
+
+
+/*
+ * Drop a reference; the last one releases private data, filters,
+ * queued messages, identity, SASL/TLS state and the socket.
+ * NULL is accepted.
+ */
+void virNetServerClientFree(virNetServerClientPtr client)
+{
+    virNetServerClientFilterPtr filter;
+
+    if (!client)
+        return;
+
+    virNetServerClientLock(client);
+    VIR_DEBUG("client=%p refs=%d", client, client->refs);
+
+    client->refs--;
+    if (client->refs > 0) {
+        virNetServerClientUnlock(client);
+        return;
+    }
+
+    if (client->privateData &&
+        client->privateDataFreeFunc)
+        client->privateDataFreeFunc(client->privateData);
+
+    /* Filter nodes are owned by the client; their opaque data is not */
+    while ((filter = client->filters) != NULL) {
+        client->filters = filter->next;
+        VIR_FREE(filter);
+    }
+
+    /* Messages remain queued if the client was never closed */
+    virNetServerClientDiscardMessages(client);
+
+    VIR_FREE(client->identity);
+#if HAVE_SASL
+    virNetSASLSessionFree(client->sasl);
+#endif
+    virNetTLSSessionFree(client->tls);
+    virNetTLSContextFree(client->tlsCtxt);
+    virNetSocketFree(client->sock);
+    virNetServerClientUnlock(client);
+    virMutexDestroy(&client->lock);
+    VIR_FREE(client);
+}
+
+
+/*
+ *
+ * We don't free stuff here, merely disconnect the client's
+ * network socket & resources.
+ *
+ * Full free of the client is done later in a safe point
+ * where it can be guaranteed it is no longer in use
+ */
+void virNetServerClientClose(virNetServerClientPtr client)
+{
+    virNetServerClientLock(client);
+    VIR_DEBUG("client=%p refs=%d", client, client->refs);
+    if (!client->sock) {
+        virNetServerClientUnlock(client);
+        return;
+    }
+
+    /* Do now, even though we don't close the socket
+     * until end, to ensure we don't get invoked
+     * again due to tls shutdown */
+    virNetSocketRemoveIOCallback(client->sock);
+
+    if (client->tls) {
+        virNetTLSSessionFree(client->tls);
+        client->tls = NULL;
+    }
+    virNetSocketFree(client->sock);
+    client->sock = NULL;
+
+    virNetServerClientDiscardMessages(client);
+
+    virNetServerClientUnlock(client);
+}
+
+
+/* True once virNetServerClientClose() has released the socket */
+bool virNetServerClientIsClosed(virNetServerClientPtr client)
+{
+    bool closed;
+    virNetServerClientLock(client);
+    closed = client->sock == NULL ? true : false;
+    virNetServerClientUnlock(client);
+    return closed;
+}
+
+/* Ask for the client to be closed at the server's next safe point */
+void virNetServerClientMarkClose(virNetServerClientPtr client)
+{
+    virNetServerClientLock(client);
+    client->wantClose = true;
+    virNetServerClientUnlock(client);
+}
+
+bool virNetServerClientWantClose(virNetServerClientPtr client)
+{
+    bool wantClose;
+    virNetServerClientLock(client);
+    wantClose = client->wantClose;
+    virNetServerClientUnlock(client);
+    return wantClose;
+}
+
+
+/*
+ * Start I/O on the client: register the socket callback and, when a
+ * TLS context was given, create the session and begin the handshake.
+ *
+ * Returns 0 on success; -1 on failure, with the client marked
+ * for closing
+ */
+int virNetServerClientInit(virNetServerClientPtr client)
+{
+    virNetServerClientLock(client);
+
+    if (!client->tlsCtxt) {
+        /* Plain socket, so prepare to read first message */
+        if (virNetServerClientRegisterEvent(client) < 0)
+            goto error;
+    } else {
+        int ret;
+
+        if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt,
+                                                NULL)))
+            goto error;
+
+        virNetSocketSetTLSSession(client->sock,
+                                  client->tls);
+
+        /* Begin the TLS handshake. */
+        ret = virNetTLSSessionHandshake(client->tls);
+        if (ret == 0) {
+            /* Unlikely, but ...  Next step is to check the certificate. */
+            if (virNetServerClientCheckAccess(client) < 0)
+                goto error;
+
+            /* Handshake & cert check OK,  so prepare to read first message */
+            if (virNetServerClientRegisterEvent(client) < 0)
+                goto error;
+        } else if (ret > 0) {
+            /* Most likely, need to do more handshake data */
+            if (virNetServerClientRegisterEvent(client) < 0)
+                goto error;
+        } else {
+            goto error;
+        }
+    }
+
+    virNetServerClientUnlock(client);
+    return 0;
+
+error:
+    client->wantClose = true;
+    virNetServerClientUnlock(client);
+    return -1;
+}
+
+
+
+/*
+ * Read data into buffer using wire decoding (plain or TLS)
+ *
+ * @client: a locked client object with a non-NULL rx message
+ *
+ * Returns:
+ *   -1 on error or EOF
+ *    0 on EAGAIN
+ *    n number of bytes
+ */
+static ssize_t virNetServerClientRead(virNetServerClientPtr client)
+{
+    ssize_t ret;
+
+    if (client->rx->bufferLength <= client->rx->bufferOffset) {
+        virNetError(VIR_ERR_RPC,
+                    _("unexpected zero/negative length request %lld"),
+                    (long long int)(client->rx->bufferLength - client->rx->bufferOffset));
+        client->wantClose = true;
+        return -1;
+    }
+
+    ret = virNetSocketRead(client->sock,
+                           client->rx->buffer + client->rx->bufferOffset,
+                           client->rx->bufferLength - client->rx->bufferOffset);
+
+    if (ret <= 0)
+        return ret;
+
+    client->rx->bufferOffset += ret;
+    return ret;
+}
+
+
+/*
+ * Read data until we get a complete message to process
+ *
+ * @client: a locked client object with a non-NULL rx message
+ *
+ * Any failure marks the client with wantClose so that it is
+ * closed at the server's next safe point.
+ */
+static void virNetServerClientDispatchRead(virNetServerClientPtr client)
+{
+readmore:
+    if (virNetServerClientRead(client) < 0) {
+        client->wantClose = true;
+        return; /* Error */
+    }
+
+    if (client->rx->bufferOffset < client->rx->bufferLength)
+        return; /* Still not read enough */
+
+    /* Either done with length word header */
+    if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) {
+        if (virNetMessageDecodeLength(client->rx) < 0) {
+            /* A bogus length word leaves the stream unframed,
+             * so there is no way to carry on with this client */
+            client->wantClose = true;
+            return;
+        }
+
+        virNetServerClientUpdateEvent(client);
+
+        /* Try and read payload immediately instead of going back
+           into poll() because chances are the data is already
+           waiting for us */
+        goto readmore;
+    } else {
+        /* Grab the completed message */
+        virNetMessagePtr msg = virNetMessageQueueServe(&client->rx);
+        virNetServerClientFilterPtr filter;
+
+        /* Decode the header so we can use it for routing decisions */
+        if (virNetMessageDecodeHeader(msg) < 0) {
+            virNetMessageFree(msg);
+            client->wantClose = true;
+            return;
+        }
+
+        /* Maybe send off for queue against a filter */
+        filter = client->filters;
+        while (filter) {
+            int ret = filter->func(client, msg, filter->opaque);
+            if (ret < 0) {
+                /* Filter rejected the message: drop it and the client */
+                virNetMessageFree(msg);
+                msg = NULL;
+                client->wantClose = true;
+                break;
+            } else if (ret > 0) {
+                /* Filter queued the message for its own in-order
+                 * processing, so it must not be freed here */
+                msg = NULL;
+                break;
+            }
+
+            filter = filter->next;
+        }
+
+        /* Send off for normal dispatch to workers */
+        if (msg) {
+            if (!client->dispatchFunc ||
+                client->dispatchFunc(client, msg, client->dispatchOpaque) < 0) {
+                virNetMessageFree(msg);
+                client->wantClose = true;
+                return;
+            }
+        }
+
+        /* Possibly need to create another receive buffer */
+        if (client->nrequests < client->nrequests_max) {
+            if (!(client->rx = virNetMessageNew())) {
+                /* rx stays NULL; never touch it in that case */
+                client->wantClose = true;
+            } else {
+                client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+                client->nrequests++;
+            }
+        }
+        virNetServerClientUpdateEvent(client);
+    }
+}
+
+
+/*
+ * Send client->tx using no encoding
+ *
+ * @client: a locked client object with a non-NULL tx message
+ *
+ * Returns:
+ *   -1 on error or EOF
+ *    0 on EAGAIN
+ *    n number of bytes
+ */
+static ssize_t virNetServerClientWrite(virNetServerClientPtr client)
+{
+    ssize_t ret;
+
+    if (client->tx->bufferLength < client->tx->bufferOffset) {
+        virNetError(VIR_ERR_RPC,
+                    _("unexpected zero/negative length request %lld"),
+                    (long long int)(client->tx->bufferLength - client->tx->bufferOffset));
+        client->wantClose = true;
+        return -1;
+    }
+
+    if (client->tx->bufferLength == client->tx->bufferOffset)
+        return 1;
+
+    ret = virNetSocketWrite(client->sock,
+                            client->tx->buffer + client->tx->bufferOffset,
+                            client->tx->bufferLength - client->tx->bufferOffset);
+    if (ret <= 0)
+        return ret; /* -1 error, 0 = egain */
+
+    client->tx->bufferOffset += ret;
+    return ret;
+}
+
+
+/*
+ * Process all queued client->tx messages until
+ * we would block on I/O
+ *
+ * @client: a locked client object
+ */
+static void
+virNetServerClientDispatchWrite(virNetServerClientPtr client)
+{
+    while (client->tx) {
+        ssize_t ret;
+
+        ret = virNetServerClientWrite(client);
+        if (ret < 0) {
+            client->wantClose = true;
+            return;
+        }
+        if (ret == 0)
+            return; /* Would block on write EAGAIN */
+
+        if (client->tx->bufferOffset == client->tx->bufferLength) {
+            virNetMessagePtr msg;
+#if HAVE_SASL
+            /* Completed this 'tx' operation, so now switch all
+             * future rx/tx to be under a SASL SSF layer
+             */
+            if (client->sasl) {
+                virNetSocketSetSASLSession(client->sock, client->sasl);
+                virNetSASLSessionFree(client->sasl);
+                client->sasl = NULL;
+            }
+#endif
+
+            /* Get finished msg from head of tx queue */
+            msg = virNetMessageQueueServe(&client->tx);
+
+            if (msg->header.type == VIR_NET_REPLY) {
+                client->nrequests--;
+                /* See if the recv queue is currently throttled */
+                if (!client->rx &&
+                    client->nrequests < client->nrequests_max) {
+                    /* Ready to recv more messages */
+                    client->rx = msg;
+                    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+                    msg = NULL;
+                    client->nrequests++;
+                }
+            }
+
+            virNetMessageFree(msg);
+
+            virNetServerClientUpdateEvent(client);
+        }
+    }
+}
+
+/*
+ * Advance the TLS handshake; once complete, run the access check
+ *
+ * @client: a locked client object with a TLS session
+ */
+static void
+virNetServerClientDispatchHandshake(virNetServerClientPtr client)
+{
+    int ret;
+    /* Continue the handshake. */
+    ret = virNetTLSSessionHandshake(client->tls);
+    if (ret == 0) {
+        /* Finished.  Next step is to check the certificate. */
+        if (virNetServerClientCheckAccess(client) < 0)
+            client->wantClose = true;
+        else
+            virNetServerClientUpdateEvent(client);
+    } else if (ret > 0) {
+        /* Carry on waiting for more handshake. Update
+           the events just in case handshake data flow
+           direction has changed */
+        virNetServerClientUpdateEvent(client);
+    } else {
+        /* Fatal error in handshake */
+        client->wantClose = true;
+    }
+}
+
+/*
+ * Socket I/O callback registered by virNetServerClientRegisterEvent
+ */
+static void
+virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque)
+{
+    virNetServerClientPtr client = opaque;
+
+    virNetServerClientLock(client);
+
+    if (client->sock != sock) {
+        virNetSocketRemoveIOCallback(sock);
+        virNetServerClientUnlock(client);
+        return;
+    }
+
+    if (events & (VIR_EVENT_HANDLE_WRITABLE |
+                  VIR_EVENT_HANDLE_READABLE)) {
+        if (client->tls &&
+            virNetTLSSessionGetHandshakeStatus(client->tls) !=
+            VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+            virNetServerClientDispatchHandshake(client);
+        } else {
+            if (events & VIR_EVENT_HANDLE_WRITABLE)
+                virNetServerClientDispatchWrite(client);
+            /* rx is NULL while throttled; there is nothing to
+             * read into, and DispatchRead dereferences it */
+            if ((events & VIR_EVENT_HANDLE_READABLE) &&
+                client->rx)
+                virNetServerClientDispatchRead(client);
+        }
+    }
+
+    /* NB, will get HANGUP + READABLE at same time upon
+     * disconnect */
+    if (events & (VIR_EVENT_HANDLE_ERROR |
+                  VIR_EVENT_HANDLE_HANGUP))
+        client->wantClose = true;
+
+    virNetServerClientUnlock(client);
+}
+
+
+/*
+ * Queue @msg for transmission to the client.
+ *
+ * Returns 0 if queued (the client now owns @msg), or -1 if the
+ * client is closed or closing (the caller still owns @msg)
+ */
+int virNetServerClientSendMessage(virNetServerClientPtr client,
+                                  virNetMessagePtr msg)
+{
+    int ret = -1;
+    VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu",
+              msg, msg->header.proc,
+              msg->bufferLength, msg->bufferOffset);
+    virNetServerClientLock(client);
+
+    if (client->sock && !client->wantClose) {
+        virNetMessageQueuePush(&client->tx, msg);
+
+        virNetServerClientUpdateEvent(client);
+        ret = 0;
+    }
+
+    virNetServerClientUnlock(client);
+    return ret;
+}
+
+
+/* True if an auth scheme is set and no identity is recorded yet */
+bool virNetServerClientNeedAuth(virNetServerClientPtr client)
+{
+    bool need = false;
+    virNetServerClientLock(client);
+    if (client->auth && !client->identity)
+        need = true;
+    virNetServerClientUnlock(client);
+    return need;
+}
diff --git a/src/rpc/virnetserverclient.h b/src/rpc/virnetserverclient.h
new file mode 100644
index 0000000..e573d3a
--- /dev/null
+++ b/src/rpc/virnetserverclient.h
@@
-0,0 +1,106 @@
+/*
+ * virnetserverclient.h: generic network RPC server client
+ *
+ * Copyright (C) 2006-2010 Red Hat, Inc.
+ * Copyright (C) 2006 Daniel P. Berrange
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_SERVER_CLIENT_H__
+# define __VIR_NET_SERVER_CLIENT_H__
+
+# include "virnetsocket.h"
+# include "virnetmessage.h"
+
+typedef struct _virNetServerClient virNetServerClient;
+typedef virNetServerClient *virNetServerClientPtr;
+
+/* Receives each complete incoming message that no filter captured.
+ * A negative return makes the client free @msg and mark itself for
+ * closing; on success the client does not free @msg. */
+typedef int (*virNetServerClientDispatchFunc)(virNetServerClientPtr client,
+                                              virNetMessagePtr msg,
+                                              void *opaque);
+
+/* Consulted (newest first) before normal dispatch. Returning 0 passes
+ * @msg on; a non-zero return stops further filtering and normal
+ * dispatch, and a negative one also marks the client for closing. */
+typedef int (*virNetServerClientFilterFunc)(virNetServerClientPtr client,
+                                            virNetMessagePtr msg,
+                                            void *opaque);
+
+/* Create a client holding one reference for an accepted @sock, taking
+ * a reference on @tls when non-NULL. On failure NULL is returned and
+ * the caller still owns @sock. */
+virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
+                                            int auth,
+                                            bool readonly,
+                                            virNetTLSContextPtr tls);
+
+/* Returns the new filter's ID, or -1 on OOM */
+int virNetServerClientAddFilter(virNetServerClientPtr client,
+                                virNetServerClientFilterFunc func,
+                                void *opaque);
+
+/* Unregister the filter whose ID was returned by AddFilter */
+void virNetServerClientRemoveFilter(virNetServerClientPtr client,
+                                    int filterID);
+
+/* The auth and readonly values given to virNetServerClientNew() */
+int virNetServerClientGetAuth(virNetServerClientPtr client);
+bool virNetServerClientGetReadonly(virNetServerClientPtr client);
+
+/* True while the client has a TLS session */
+bool virNetServerClientHasTLSSession(virNetServerClientPtr client);
+/* Key size of the TLS session, or 0 without one */
+int virNetServerClientGetTLSKeySize(virNetServerClientPtr client);
+
+/* NOTE(review): virnetserverclient.c guards its SASL code with
+ * "#if HAVE_SASL" while this uses "#ifdef" - keep the two in sync. */
+#ifdef HAVE_SASL
+/* Takes a reference on @sasl; the socket switches to SASL only after
+ * the next transmitted message completes. */
+void virNetServerClientSetSASLSession(virNetServerClientPtr client,
+                                      virNetSASLSessionPtr sasl);
+#endif
+
+int virNetServerClientGetFD(virNetServerClientPtr client);
+
+/* True if the connection uses TLS, SASL, or a local socket */
+bool virNetServerClientIsSecure(virNetServerClientPtr client);
+
+/* Store a copy of @identity. Returns 0, or -1 on OOM */
+int virNetServerClientSetIdentity(virNetServerClientPtr client,
+                                  const char *identity);
+/* Returns the client-owned identity string (not a copy), or NULL */
+const char *virNetServerClientGetIdentity(virNetServerClientPtr client);
+
+int virNetServerClientGetLocalIdentity(virNetServerClientPtr client,
+                                       uid_t *uid, pid_t *pid);
+
+/* Take an extra reference; drop it with virNetServerClientFree() */
+void virNetServerClientRef(virNetServerClientPtr client);
+
+typedef void (*virNetServerClientFreeFunc)(void *data);
+
+/* Attach @opaque; any previous data is first passed to its free
+ * function, and @ff is called on @opaque when the client is freed */
+void virNetServerClientSetPrivateData(virNetServerClientPtr client,
+                                      void *opaque,
+                                      virNetServerClientFreeFunc ff);
+void *virNetServerClientGetPrivateData(virNetServerClientPtr client);
+
+void virNetServerClientSetDispatcher(virNetServerClientPtr client,
+                                     virNetServerClientDispatchFunc func,
+                                     void *opaque);
+/* Disconnect the socket and drop queued messages; the client object
+ * itself stays alive until its last reference is freed */
+void virNetServerClientClose(virNetServerClientPtr client);
+
+bool virNetServerClientIsClosed(virNetServerClientPtr client);
+/* Request closing at the server's next safe point */
+void virNetServerClientMarkClose(virNetServerClientPtr client);
+bool virNetServerClientWantClose(virNetServerClientPtr client);
+
+/* Start socket I/O (and the TLS handshake when a context was given).
+ * Returns 0, or -1 with the client marked for closing. */
+int virNetServerClientInit(virNetServerClientPtr client);
+
+const char *virNetServerClientLocalAddrString(virNetServerClientPtr client);
+const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client);
+
+/* Queue @msg for transmission. Returns 0 once queued, or -1 if the
+ * client is closed or closing (@msg is then not queued) */
+int virNetServerClientSendMessage(virNetServerClientPtr client,
+                                  virNetMessagePtr msg);
+
+/* True if an auth scheme is set but no identity has been recorded */
+bool virNetServerClientNeedAuth(virNetServerClientPtr client);
+
+/* Drop a reference, releasing everything on the last one */
+void virNetServerClientFree(virNetServerClientPtr client);
+
+
+#endif /* __VIR_NET_SERVER_CLIENT_H__ */
diff --git a/src/rpc/virnetserverprogram.c b/src/rpc/virnetserverprogram.c
new file mode 100644
index 0000000..cbd4262
--- /dev/null
+++ b/src/rpc/virnetserverprogram.c
@@ -0,0 +1,456 @@
+/*
+ * virnetserverprogram.c: generic
network RPC server program + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include "virnetserverprogram.h" +#include "virnetserverclient.h" + +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) 
\ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetServerProgram { + int refs; + + unsigned program; + unsigned version; + virNetServerProgramProcPtr procs; + size_t nprocs; +}; + +virNetServerProgramPtr virNetServerProgramNew(unsigned program, + unsigned version, + virNetServerProgramProcPtr procs, + size_t nprocs) +{ + virNetServerProgramPtr prog; + + if (VIR_ALLOC(prog) < 0) { + virReportOOMError(); + return NULL; + } + + prog->refs = 1; + prog->program = program; + prog->version = version; + prog->procs = procs; + prog->nprocs = nprocs; + + VIR_DEBUG("prog=%p refs=%d", prog, prog->refs); + + return prog; +} + + +int virNetServerProgramGetID(virNetServerProgramPtr prog) +{ + return prog->program; +} + + +int virNetServerProgramGetVersion(virNetServerProgramPtr prog) +{ + return prog->version; +} + + +void virNetServerProgramRef(virNetServerProgramPtr prog) +{ + prog->refs++; + VIR_DEBUG("prog=%p refs=%d", prog, prog->refs); +} + + +int virNetServerProgramMatches(virNetServerProgramPtr prog, + virNetMessagePtr msg) +{ + if (prog->program == msg->header.prog && + prog->version == msg->header.vers) + return 1; + return 0; +} + + +static virNetServerProgramProcPtr virNetServerProgramGetProc(virNetServerProgramPtr prog, + int procedure) +{ + if (procedure < 0) + return NULL; + if ((size_t)procedure >= prog->nprocs) + return NULL; + /* Table slots with no handler are reported as unknown procedures */ + return prog->procs[procedure].func ? &prog->procs[procedure] : NULL; +} + + +static int +virNetServerProgramSendError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + int procedure, + int type, + int serial) +{ + VIR_DEBUG("prog=%d ver=%d proc=%d type=%d serial=%d msg=%p rerr=%p", + prog->program, prog->version, procedure, type, serial, msg, rerr); + + virNetMessageSaveError(rerr); + + /* Return header.
*/ + msg->header.prog = prog->program; + msg->header.vers = prog->version; + msg->header.proc = procedure; + msg->header.type = type; + msg->header.serial = serial; + msg->header.status = VIR_NET_ERROR; + + if (virNetMessageEncodeHeader(msg) < 0) + goto error; + + if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetMessageError, rerr) < 0) + goto error; + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr); + + /* Put reply on end of tx queue to send out */ + if (virNetServerClientSendMessage(client, msg) < 0) + return -1; + + return 0; + +error: + VIR_WARN("Failed to serialize remote error '%p'", rerr); + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr); + return -1; +} + + +/* + * @client: the client to send the error to + * @req: the message this error is in reply to + * + * Send an error message to the client + * + * Returns 0 if the error was sent, -1 upon fatal error + */ +int +virNetServerProgramSendReplyError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + virNetMessageHeaderPtr req) +{ + /* + * For data streams, errors are sent back as data streams + * For method calls, errors are sent back as method replies + */ + return virNetServerProgramSendError(prog, + client, + msg, + rerr, + req->proc, + req->type == VIR_NET_STREAM ? 
VIR_NET_STREAM : VIR_NET_REPLY, + req->serial); +} + + +int virNetServerProgramSendStreamError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + int procedure, + int serial) +{ + return virNetServerProgramSendError(prog, + client, + msg, + rerr, + procedure, + VIR_NET_STREAM, + serial); +} + + +static int +virNetServerProgramDispatchCall(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg); + +/* + * @server: the unlocked server object + * @client: the unlocked client object + * @msg: the complete incoming message packet, with header already decoded + * + * This function is intended to be called from worker threads + * when an incoming message is ready to be dispatched for + * execution. + * + * Upon successful return the '@msg' instance will be released + * by this function (or more often, reused to send a reply). + * Upon failure, the '@msg' must be freed by the caller. + * + * Returns 0 if the message was dispatched, -1 upon fatal error + */ +int virNetServerProgramDispatch(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg) +{ + int ret = -1; + virNetMessageError rerr; + + memset(&rerr, 0, sizeof(rerr)); + + VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%d proc=%d", + msg->header.prog, msg->header.vers, msg->header.type, + msg->header.status, msg->header.serial, msg->header.proc); + + /* Check version, etc. 
*/ + if (msg->header.prog != prog->program) { + virNetError(VIR_ERR_RPC, + _("program mismatch (actual %x, expected %x)"), + msg->header.prog, prog->program); + goto error; + } + + if (msg->header.vers != prog->version) { + virNetError(VIR_ERR_RPC, + _("version mismatch (actual %x, expected %x)"), + msg->header.vers, prog->version); + goto error; + } + + switch (msg->header.type) { + case VIR_NET_CALL: + ret = virNetServerProgramDispatchCall(prog, server, client, msg); + break; + + case VIR_NET_STREAM: + /* Since stream data is non-acked, async, we may continue to receive + * stream packets after we closed down a stream. Just drop & ignore + * these. + */ + VIR_INFO("Ignoring unexpected stream data serial=%d proc=%d status=%d", + msg->header.serial, msg->header.proc, msg->header.status); + virNetMessageFree(msg); + ret = 0; + break; + + default: + virNetError(VIR_ERR_RPC, + _("Unexpected message type %u"), + msg->header.type); + goto error; + } + + return ret; + +error: + ret = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header); + + return ret; +} + + +/* + * @server: the unlocked server object + * @client: the unlocked client object + * @msg: the complete incoming method call, with header already decoded + * + * This method is used to dispatch an message representing an + * incoming method call from a client. 
It decodes the payload + * to obtain method call arguments, invokves the method and + * then sends a reply packet with the return values + * + * Returns 0 if the reply was sent, or -1 upon fatal error + */ +static int +virNetServerProgramDispatchCall(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg) +{ + char *arg = NULL; + char *ret = NULL; + int rv = -1; + virNetServerProgramProcPtr dispatcher; + virNetMessageError rerr; + + memset(&rerr, 0, sizeof(rerr)); + + if (msg->header.status != VIR_NET_OK) { + virNetError(VIR_ERR_RPC, + _("Unexpected message status %u"), + msg->header.status); + goto error; + } + + dispatcher = virNetServerProgramGetProc(prog, msg->header.proc); + + if (!dispatcher) { + virNetError(VIR_ERR_RPC, + _("unknown procedure: %d"), + msg->header.proc); + goto error; + } + + /* If client is marked as needing auth, don't allow any RPC ops + * which are except for authentication ones + */ + if (virNetServerClientNeedAuth(client) && + dispatcher->needAuth) { + /* Explicitly *NOT* calling remoteDispatchAuthError() because + we want back-compatability with libvirt clients which don't + support the VIR_ERR_AUTH_FAILED error code */ + virNetError(VIR_ERR_RPC, + "%s", _("authentication required")); + goto error; + } + + if (VIR_ALLOC_N(arg, dispatcher->arg_len) < 0) { + virReportOOMError(); + goto error; + } + if (VIR_ALLOC_N(ret, dispatcher->ret_len) < 0) { + virReportOOMError(); + goto error; + } + + if (virNetMessageDecodePayload(msg, dispatcher->arg_filter, arg) < 0) + goto error; + + /* + * When the RPC handler is called: + * + * - Server object is unlocked + * - Client object is unlocked + * + * Without locking, it is safe to use: + * + * 'args and 'ret' + */ + rv = (dispatcher->func)(server, client, &msg->header, &rerr, arg, ret); + + xdr_free(dispatcher->arg_filter, arg); + + if (rv < 0) + goto error; + + /* Return header. 
We're re-using same message object, so + * only need to tweak type/status fields */ + /*msg->header.prog = msg->header.prog;*/ + /*msg->header.vers = msg->header.vers;*/ + /*msg->header.proc = msg->header.proc;*/ + msg->header.type = VIR_NET_REPLY; + /*msg->header.serial = msg->header.serial;*/ + msg->header.status = VIR_NET_OK; + + if (virNetMessageEncodeHeader(msg) < 0) { + xdr_free(dispatcher->ret_filter, ret); + goto error; + } + + if (virNetMessageEncodePayload(msg, dispatcher->ret_filter, ret) < 0) { + xdr_free(dispatcher->ret_filter, ret); + goto error; + } + + xdr_free(dispatcher->ret_filter, ret); + VIR_FREE(arg); + VIR_FREE(ret); + + /* Put reply on end of tx queue to send out */ + return virNetServerClientSendMessage(client, msg); + +error: + /* Bad stuff (de-)serializing message, but we have an + * RPC error message we can send back to the client */ + rv = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header); + + VIR_FREE(arg); + VIR_FREE(ret); + + return rv; +} + + +int virNetServerProgramSendStreamData(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + int procedure, + int serial, + const char *data, + size_t len) +{ + VIR_DEBUG("client=%p msg=%p data=%p len=%zu", client, msg, data, len); + + /* Return header. We're reusing same message object, so + * only need to tweak type/status fields */ + msg->header.prog = prog->program; + msg->header.vers = prog->version; + msg->header.proc = procedure; + msg->header.type = VIR_NET_STREAM; + msg->header.serial = serial; + /* + * NB + * data != NULL + len > 0 => REMOTE_CONTINUE (Sending back data) + * data != NULL + len == 0 => REMOTE_CONTINUE (Sending read EOF) + * data == NULL => REMOTE_OK (Sending finish handshake confirmation) + */ + msg->header.status = data ? 
VIR_NET_CONTINUE : VIR_NET_OK; + + if (virNetMessageEncodeHeader(msg) < 0) + return -1; + + if (data && len) { + if (virNetMessageEncodePayloadRaw(msg, data, len) < 0) + return -1; + + VIR_DEBUG("Total %zu", msg->bufferOffset); + } + + return virNetServerClientSendMessage(client, msg); +} + + +void virNetServerProgramFree(virNetServerProgramPtr prog) +{ + if (!prog) + return; + + VIR_DEBUG("prog=%p refs=%d", prog, prog->refs); + + prog->refs--; + if (prog->refs > 0) + return; + + VIR_FREE(prog); +} diff --git a/src/rpc/virnetserverprogram.h b/src/rpc/virnetserverprogram.h new file mode 100644 index 0000000..8277e7f --- /dev/null +++ b/src/rpc/virnetserverprogram.h @@ -0,0 +1,107 @@ +/* + * virnetserverprogram.h: generic network RPC server program + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_PROGRAM_H__ +# define __VIR_NET_PROGRAM_H__ + +# include <stdbool.h> + +# include "virnetmessage.h" +# include "virnetserverclient.h" + +typedef struct _virNetServer virNetServer; +typedef virNetServer *virNetServerPtr; + +typedef struct _virNetServerService virNetServerService; +typedef virNetServerService *virNetServerServicePtr; + +typedef struct _virNetServerProgram virNetServerProgram; +typedef virNetServerProgram *virNetServerProgramPtr; + +typedef struct _virNetServerProgramProc virNetServerProgramProc; +typedef virNetServerProgramProc *virNetServerProgramProcPtr; + +typedef struct _virNetServerProgramErrorHandler virNetServerProgramErrorHander; +typedef virNetServerProgramErrorHander *virNetServerProgramErrorHanderPtr; + +typedef int (*virNetServerProgramDispatchFunc)(virNetServerPtr server, + virNetServerClientPtr client, + virNetMessageHeaderPtr hdr, + virNetMessageErrorPtr rerr, + void *args, + void *ret); + +struct _virNetServerProgramProc { + virNetServerProgramDispatchFunc func; + size_t arg_len; + xdrproc_t arg_filter; + size_t ret_len; + xdrproc_t ret_filter; + bool needAuth; +}; + +virNetServerProgramPtr virNetServerProgramNew(unsigned program, + unsigned version, + virNetServerProgramProcPtr procs, + size_t nprocs); + +int virNetServerProgramGetID(virNetServerProgramPtr prog); +int virNetServerProgramGetVersion(virNetServerProgramPtr prog); + +void virNetServerProgramRef(virNetServerProgramPtr prog); + +int virNetServerProgramMatches(virNetServerProgramPtr prog, + virNetMessagePtr msg); + +int virNetServerProgramDispatch(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg); + +int virNetServerProgramSendReplyError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + virNetMessageHeaderPtr req); + +int virNetServerProgramSendStreamError(virNetServerProgramPtr prog, + 
virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + int procedure, + int serial); + +int virNetServerProgramSendStreamData(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + int procedure, + int serial, + const char *data, + size_t len); + +void virNetServerProgramFree(virNetServerProgramPtr prog); + + + + +#endif /* __VIR_NET_SERVER_PROGRAM_H__ */ diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c new file mode 100644 index 0000000..3f6cf4b --- /dev/null +++ b/src/rpc/virnetserverservice.c @@ -0,0 +1,247 @@ +/* + * virnetserverservice.c: generic network RPC server service + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include "virnetserverservice.h" + +#include "memory.h" +#include "virterror_internal.h" + + +#define VIR_FROM_THIS VIR_FROM_RPC + +struct _virNetServerService { + int refs; + + size_t nsocks; + virNetSocketPtr *socks; + + int auth; + bool readonly; + + virNetTLSContextPtr tls; + + virNetServerServiceDispatchFunc dispatchFunc; + void *dispatchOpaque; +}; + + + +static void virNetServerServiceAccept(virNetSocketPtr sock, + int events ATTRIBUTE_UNUSED, + void *opaque) +{ + virNetServerServicePtr svc = opaque; + virNetServerClientPtr client = NULL; + virNetSocketPtr clientsock = NULL; + + if (virNetSocketAccept(sock, &clientsock) < 0) + goto error; + + if (!clientsock) /* Connection already went away */ + goto cleanup; + + if (!svc->dispatchFunc) /* Check before wrapping clientsock in a client, so 'error' only has the socket to free */ + goto error; + + if (!(client = virNetServerClientNew(clientsock, + svc->auth, + svc->readonly, + svc->tls))) + goto error; + + if (svc->dispatchFunc(svc, client, svc->dispatchOpaque) < 0) + virNetServerClientClose(client); + + virNetServerClientFree(client); + +cleanup: + return; + +error: + virNetSocketFree(clientsock); +} + + +virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename, + const char *service, + int auth, + bool readonly, + virNetTLSContextPtr tls) +{ + virNetServerServicePtr svc; + size_t i; + + if (VIR_ALLOC(svc) < 0) + goto no_memory; + + svc->refs = 1; + svc->auth = auth; + svc->readonly = readonly; + svc->tls = tls; + if (tls) + virNetTLSContextRef(tls); + + if (virNetSocketNewListenTCP(nodename, + service, + &svc->socks, + &svc->nsocks) < 0) + goto error; + + for (i = 0 ; i < svc->nsocks ; i++) { + if (virNetSocketListen(svc->socks[i]) < 0) + goto error; + + /* IO callback is initially disabled, until we're ready + * to deal with incoming clients */ + if (virNetSocketAddIOCallback(svc->socks[i], + 0, + virNetServerServiceAccept, + svc) < 0) + goto error; + } + + + return svc; + +no_memory: + virReportOOMError(); +error: +
virNetServerServiceFree(svc); + return NULL; +} + + +virNetServerServicePtr virNetServerServiceNewUNIX(const char *path, + mode_t mask, + gid_t grp, + int auth, + bool readonly, + virNetTLSContextPtr tls) +{ + virNetServerServicePtr svc; + int i; + + if (VIR_ALLOC(svc) < 0) + goto no_memory; + + svc->refs = 1; + svc->auth = auth; + svc->readonly = readonly; + svc->tls = tls; + if (tls) + virNetTLSContextRef(tls); + + svc->nsocks = 1; + if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0) + goto no_memory; + + if (virNetSocketNewListenUNIX(path, + mask, + grp, + &svc->socks[0]) < 0) + goto error; + + for (i = 0 ; i < svc->nsocks ; i++) { + if (virNetSocketListen(svc->socks[i]) < 0) + goto error; + + /* IO callback is initially disabled, until we're ready + * to deal with incoming clients */ + if (virNetSocketAddIOCallback(svc->socks[i], + 0, + virNetServerServiceAccept, + svc) < 0) + goto error; + } + + + return svc; + +no_memory: + virReportOOMError(); +error: + virNetServerServiceFree(svc); + return NULL; +} + + +int virNetServerServiceGetAuth(virNetServerServicePtr svc) +{ + return svc->auth; +} + + +bool virNetServerServiceIsReadonly(virNetServerServicePtr svc) +{ + return svc->readonly; +} + + +void virNetServerServiceRef(virNetServerServicePtr svc) +{ + svc->refs++; +} + + +void virNetServerServiceSetDispatcher(virNetServerServicePtr svc, + virNetServerServiceDispatchFunc func, + void *opaque) +{ + svc->dispatchFunc = func; + svc->dispatchOpaque = opaque; +} + + +void virNetServerServiceFree(virNetServerServicePtr svc) +{ + int i; + + if (!svc) + return; + + svc->refs--; + if (svc->refs > 0) + return; + + for (i = 0 ; i < svc->nsocks ; i++) + virNetSocketFree(svc->socks[i]); + VIR_FREE(svc->socks); + + virNetTLSContextFree(svc->tls); + + VIR_FREE(svc); +} + +void virNetServerServiceToggle(virNetServerServicePtr svc, + bool enabled) +{ + int i; + + for (i = 0 ; i < svc->nsocks ; i++) + virNetSocketUpdateIOCallback(svc->socks[i], + enabled ? 
+ VIR_EVENT_HANDLE_READABLE : + 0); +} diff --git a/src/rpc/virnetserverservice.h b/src/rpc/virnetserverservice.h new file mode 100644 index 0000000..b59c8fa --- /dev/null +++ b/src/rpc/virnetserverservice.h @@ -0,0 +1,65 @@ +/* + * virnetserverservice.h: generic network RPC server service + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_SERVER_SERVICE_H__ +# define __VIR_NET_SERVER_SERVICE_H__ + +# include "virnetserverprogram.h" + +enum { + VIR_NET_SERVER_SERVICE_AUTH_NONE = 0, + VIR_NET_SERVER_SERVICE_AUTH_SASL, + VIR_NET_SERVER_SERVICE_AUTH_POLKIT, +}; + +typedef int (*virNetServerServiceDispatchFunc)(virNetServerServicePtr svc, + virNetServerClientPtr client, + void *opaque); + +virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename, + const char *service, + int auth, + bool readonly, + virNetTLSContextPtr tls); +virNetServerServicePtr virNetServerServiceNewUNIX(const char *path, + mode_t mask, + gid_t grp, + int auth, + bool readonly, + virNetTLSContextPtr tls); + +int virNetServerServiceGetAuth(virNetServerServicePtr svc); +bool virNetServerServiceIsReadonly(virNetServerServicePtr svc); + +void virNetServerServiceRef(virNetServerServicePtr svc); + +void virNetServerServiceSetDispatcher(virNetServerServicePtr svc, + virNetServerServiceDispatchFunc func, + void *opaque); + +void virNetServerServiceFree(virNetServerServicePtr svc); + +void virNetServerServiceToggle(virNetServerServicePtr svc, + bool enabled); + +#endif -- 1.7.4

Allow RPC servers to advertize themselves using MDNS, via Avahi * src/rpc/virnetserver.c, src/rpc/virnetserver.h: Allow registration of MDNS services via avahi * src/rpc/virnetserverservice.c, src/rpc/virnetserverservice.h: Add API to fetch the listen port number * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Add API to fetch the local port number * src/rpc/virnetservermdns.c, src/rpc/virnetservermdns.h: Represent an MDNS advertisement --- src/Makefile.am | 9 + src/rpc/virnetserver.c | 47 +++- src/rpc/virnetserver.h | 4 +- src/rpc/virnetservermdns.c | 616 +++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetservermdns.h | 108 +++++++ src/rpc/virnetserverservice.c | 8 + src/rpc/virnetserverservice.h | 2 + src/rpc/virnetsocket.c | 6 + src/rpc/virnetsocket.h | 2 + 9 files changed, 800 insertions(+), 2 deletions(-) create mode 100644 src/rpc/virnetservermdns.c create mode 100644 src/rpc/virnetservermdns.h diff --git a/src/Makefile.am b/src/Makefile.am index b0a96b8..3a724d1 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1231,10 +1231,19 @@ libvirt_net_rpc_server_la_SOURCES = \ rpc/virnetserverservice.h rpc/virnetserverservice.c \ rpc/virnetserverclient.h rpc/virnetserverclient.c \ rpc/virnetserver.h rpc/virnetserver.c +if HAVE_AVAHI +libvirt_net_rpc_server_la_SOURCES += \ + rpc/virnetservermdns.h rpc/virnetservermdns.c +else +EXTRA_DIST += \ + rpc/virnetservermdns.h rpc/virnetservermdns.c +endif libvirt_net_rpc_server_la_CFLAGS = \ + $(AVAHI_CFLAGS) \ $(AM_CFLAGS) libvirt_net_rpc_server_la_LDFLAGS = \ $(AM_LDFLAGS) \ + $(AVAHI_LIBS) \ $(CYGWIN_EXTRA_LDFLAGS) \ $(MINGW_EXTRA_LDFLAGS)l libvirt_net_rpc_server_la_LIBADD = \ diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c index 5dbd7a9..e857ceb 100644 --- a/src/rpc/virnetserver.c +++ b/src/rpc/virnetserver.c @@ -35,6 +35,9 @@ #include "util.h" #include "files.h" #include "event.h" +#if HAVE_AVAHI +#include "virnetservermdns.h" +#endif #define VIR_FROM_THIS VIR_FROM_RPC @@ -75,6 +78,12 @@ struct 
_virNetServer { int sigwrite; int sigwatch; + char *mdnsGroupName; +#if HAVE_AVAHI + virNetServerMDNSPtr mdns; + virNetServerMDNSGroupPtr mdnsGroup; +#endif + size_t nservices; virNetServerServicePtr *services; @@ -261,6 +270,7 @@ static void virNetServerFatalSignal(int sig, siginfo_t * siginfo ATTRIBUTE_UNUSE virNetServerPtr virNetServerNew(size_t min_workers, size_t max_workers, size_t max_clients, + const char *mdnsGroupName, virNetServerClientInitHook clientInitHook) { virNetServerPtr srv; @@ -283,6 +293,19 @@ virNetServerPtr virNetServerNew(size_t min_workers, srv->clientInitHook = clientInitHook; srv->privileged = geteuid() == 0 ? true : false; + if (mdnsGroupName && !(srv->mdnsGroupName = strdup(mdnsGroupName))) { + virReportOOMError(); + goto error; + } +#if HAVE_AVAHI + if (srv->mdnsGroupName) { + if (!(srv->mdns = virNetServerMDNSNew())) + goto error; + if (!(srv->mdnsGroup = virNetServerMDNSAddGroup(srv->mdns, mdnsGroupName))) + goto error; + } +#endif + if (virMutexInit(&srv->lock) < 0) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("cannot initialize mutex")); @@ -504,13 +527,26 @@ error: int virNetServerAddService(virNetServerPtr srv, - virNetServerServicePtr svc) + virNetServerServicePtr svc, + const char *mdnsEntryName ATTRIBUTE_UNUSED) { virNetServerLock(srv); +#if HAVE_AVAHI + if (mdnsEntryName && srv->mdnsGroup) { + int port = virNetServerServiceGetPort(svc); + virNetServerMDNSEntryPtr entry; + + if (!(entry = virNetServerMDNSAddEntry(srv->mdnsGroup, + mdnsEntryName, + port))) + goto error; + } +#endif + if (VIR_EXPAND_N(srv->services, srv->nservices, 1) < 0) goto no_memory; srv->services[srv->nservices-1] = svc; virNetServerServiceRef(svc); @@ -523,6 +559,9 @@ int virNetServerAddService(virNetServerPtr srv, no_memory: virReportOOMError(); +#if HAVE_AVAHI +error: +#endif virNetServerUnlock(srv); return -1; } @@ -592,6 +631,12 @@ void virNetServerRun(virNetServerPtr srv) virNetServerLock(srv); +#if HAVE_AVAHI + if (srv->mdns && + virNetServerMDNSStart(srv->mdns) < 0) + goto cleanup;
+#endif + if (srv->autoShutdownTimeout && (timerid = virEventAddTimeout(-1, virNetServerAutoShutdownTimer, diff --git a/src/rpc/virnetserver.h b/src/rpc/virnetserver.h index 8b8b6a9..75796e9 100644 --- a/src/rpc/virnetserver.h +++ b/src/rpc/virnetserver.h @@ -38,6 +38,7 @@ typedef int (*virNetServerClientInitHook)(virNetServerPtr srv, virNetServerPtr virNetServerNew(size_t min_workers, size_t max_workers, size_t max_clients, + const char *mdnsGroupName, virNetServerClientInitHook clientInitHook); typedef int (*virNetServerAutoShutdownFunc)(virNetServerPtr srv, void *opaque); @@ -59,7 +60,8 @@ int virNetServerAddSignalHandler(virNetServerPtr srv, void *opaque); int virNetServerAddService(virNetServerPtr srv, - virNetServerServicePtr svc); + virNetServerServicePtr svc, + const char *mdnsEntryName); int virNetServerAddProgram(virNetServerPtr srv, virNetServerProgramPtr prog); diff --git a/src/rpc/virnetservermdns.c b/src/rpc/virnetservermdns.c new file mode 100644 index 0000000..e449ffb --- /dev/null +++ b/src/rpc/virnetservermdns.c @@ -0,0 +1,616 @@ +/* + * virnetservermdns.c: advertise server sockets + * + * Copyright (C) 2007 Daniel P. Berrange + * + * Derived from Avahi example service provider code. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <time.h> +#include <stdio.h> +#include <stdlib.h> + +#include <avahi-client/client.h> +#include <avahi-client/publish.h> + +#include <avahi-common/alternative.h> +#include <avahi-common/simple-watch.h> +#include <avahi-common/malloc.h> +#include <avahi-common/error.h> +#include <avahi-common/timeval.h> + +#include "virnetservermdns.h" +#include "event.h" +#include "event_poll.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) \ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetServerMDNSEntry { + char *type; + int port; + virNetServerMDNSEntryPtr next; +}; + +struct _virNetServerMDNSGroup { + virNetServerMDNSPtr mdns; + AvahiEntryGroup *handle; + char *name; + virNetServerMDNSEntryPtr entry; + virNetServerMDNSGroupPtr next; +}; + +struct _virNetServerMDNS { + AvahiClient *client; + AvahiPoll *poller; + virNetServerMDNSGroupPtr group; +}; + +/* Avahi API requires this struct name in the app :-( */ +struct AvahiWatch { + int watch; + int fd; + int revents; + AvahiWatchCallback callback; + void *userdata; +}; + +/* Avahi API requires this struct name in the app :-( */ +struct AvahiTimeout { + int timer; + AvahiTimeoutCallback callback; + void *userdata; +}; + +static void virNetServerMDNSCreateServices(virNetServerMDNSGroupPtr group); + +/* Called whenever the entry group state changes */ +static void virNetServerMDNSGroupCallback(AvahiEntryGroup *g ATTRIBUTE_UNUSED, + AvahiEntryGroupState state, + void *data) +{ + virNetServerMDNSGroupPtr group = data; + + switch (state) { + case AVAHI_ENTRY_GROUP_ESTABLISHED: + /* The entry group has been established successfully */ + VIR_DEBUG("Group '%s' established", group->name); + break; + + case AVAHI_ENTRY_GROUP_COLLISION: + { + char *n; + + /* A service name collision happened. 
Let's pick a new name */ + n = avahi_alternative_service_name(group->name); + VIR_FREE(group->name); + group->name = n; + + VIR_DEBUG("Group name collision, renaming service to '%s'", group->name); + + /* And recreate the services */ + virNetServerMDNSCreateServices(group); + } + break; + + case AVAHI_ENTRY_GROUP_FAILURE : + VIR_DEBUG("Group failure: %s", + avahi_strerror(avahi_client_errno(group->mdns->client))); + + /* Some kind of failure happened while we were registering our services */ + //avahi_simple_poll_quit(simple_poll); + break; + + case AVAHI_ENTRY_GROUP_UNCOMMITED: + case AVAHI_ENTRY_GROUP_REGISTERING: + ; + } +} + +static void virNetServerMDNSCreateServices(virNetServerMDNSGroupPtr group) +{ + virNetServerMDNSPtr mdns = group->mdns; + virNetServerMDNSEntryPtr entry; + int ret; + VIR_DEBUG("Adding services to '%s'", group->name); + + /* If we've no services to advertise, just reset the group to make + * sure it is emptied of any previously advertised services */ + if (!group->entry) { + if (group->handle) + avahi_entry_group_reset(group->handle); + return; + } + + /* If this is the first time we're called, let's create a new entry group */ + if (!group->handle) { + VIR_DEBUG("Creating initial group %s", group->name); + if (!(group->handle = + avahi_entry_group_new(mdns->client, + virNetServerMDNSGroupCallback, + group))) { + VIR_DEBUG("avahi_entry_group_new() failed: %s", + avahi_strerror(avahi_client_errno(mdns->client))); + return; + } + } + + entry = group->entry; + while (entry) { + if ((ret = avahi_entry_group_add_service(group->handle, + AVAHI_IF_UNSPEC, + AVAHI_PROTO_UNSPEC, + 0, + group->name, + entry->type, + NULL, + NULL, + entry->port, + NULL)) < 0) { + VIR_DEBUG("Failed to add %s service on port %d: %s", + entry->type, entry->port, avahi_strerror(ret)); + avahi_entry_group_reset(group->handle); + return; + } + entry = entry->next; + } + + /* Tell the server to register the service */ + if ((ret = avahi_entry_group_commit(group->handle)) < 
0) { + avahi_entry_group_reset(group->handle); + VIR_DEBUG("Failed to commit entry_group: %s", + avahi_strerror(ret)); + return; + } +} + + +static void virNetServerMDNSClientCallback(AvahiClient *c, + AvahiClientState state, + void *data) +{ + virNetServerMDNSPtr mdns = data; + virNetServerMDNSGroupPtr group; + if (!mdns->client) + mdns->client = c; + + VIR_DEBUG("Callback state=%d", state); + + /* Called whenever the client or server state changes */ + switch (state) { + case AVAHI_CLIENT_S_RUNNING: + /* The server has startup successfully and registered its host + * name on the network, so it's time to create our services */ + VIR_DEBUG("Client running %p", mdns->client); + group = mdns->group; + while (group) { + virNetServerMDNSCreateServices(group); + group = group->next; + } + break; + + case AVAHI_CLIENT_FAILURE: + VIR_DEBUG("Client failure: %s", + avahi_strerror(avahi_client_errno(c))); + virNetServerMDNSStop(mdns); + virNetServerMDNSStart(mdns); + break; + + case AVAHI_CLIENT_S_COLLISION: + /* Let's drop our registered services. When the server is back + * in AVAHI_SERVER_RUNNING state we will register them + * again with the new host name. */ + + /* Fallthrough */ + + case AVAHI_CLIENT_S_REGISTERING: + /* The server records are now being established. This + * might be caused by a host name change. We need to wait + * for our own records to register until the host name is + * properly established. */ + VIR_DEBUG("Client collision/connecting %p", mdns->client); + group = mdns->group; + while (group) { + if (group->handle) + avahi_entry_group_reset(group->handle); + group = group->next; + } + break; + + case AVAHI_CLIENT_CONNECTING: + VIR_DEBUG("Client connecting.... 
%p", mdns->client); + ; + } +} + + +static void virNetServerMDNSWatchDispatch(int watch, int fd, int events, void *opaque) +{ + AvahiWatch *w = opaque; + int fd_events = virEventPollToNativeEvents(events); + VIR_DEBUG("Dispatch watch %d FD %d Event %d", watch, fd, fd_events); + w->revents = fd_events; + w->callback(w, fd, fd_events, w->userdata); +} + +static void virNetServerMDNSWatchDofree(void *w) +{ + VIR_FREE(w); +} + + +static AvahiWatch *virNetServerMDNSWatchNew(const AvahiPoll *api ATTRIBUTE_UNUSED, + int fd, AvahiWatchEvent event, + AvahiWatchCallback cb, void *userdata) +{ + AvahiWatch *w; + virEventHandleType hEvents; + if (VIR_ALLOC(w) < 0) { + virReportOOMError(); + return NULL; + } + + w->fd = fd; + w->revents = 0; + w->callback = cb; + w->userdata = userdata; + + VIR_DEBUG("New handle %p FD %d Event %d", w, w->fd, event); + hEvents = virEventPollFromNativeEvents(event); + if ((w->watch = virEventAddHandle(fd, hEvents, + virNetServerMDNSWatchDispatch, + w, + virNetServerMDNSWatchDofree)) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Failed to add watch for fd %d events %d"), fd, hEvents); + VIR_FREE(w); + return NULL; + } + + return w; +} + +static void virNetServerMDNSWatchUpdate(AvahiWatch *w, AvahiWatchEvent event) +{ + VIR_DEBUG("Update handle %p FD %d Event %d", w, w->fd, event); + virEventUpdateHandle(w->watch, event); +} + +static AvahiWatchEvent virNetServerMDNSWatchGetEvents(AvahiWatch *w) +{ + VIR_DEBUG("Get handle events %p %d", w, w->fd); + return w->revents; +} + +static void virNetServerMDNSWatchFree(AvahiWatch *w) +{ + VIR_DEBUG("Free handle %p %d", w, w->fd); + virEventRemoveHandle(w->watch); +} + +static void virNetServerMDNSTimeoutDispatch(int timer ATTRIBUTE_UNUSED, void *opaque) +{ + AvahiTimeout *t = (AvahiTimeout*)opaque; + VIR_DEBUG("Dispatch timeout %p %d", t, timer); + virEventUpdateTimeout(t->timer, -1); + t->callback(t, t->userdata); +} + +static void virNetServerMDNSTimeoutDofree(void *t) +{ + VIR_FREE(t); +} + +static 
AvahiTimeout *virNetServerMDNSTimeoutNew(const AvahiPoll *api ATTRIBUTE_UNUSED, + const struct timeval *tv, + AvahiTimeoutCallback cb, + void *userdata) +{ + AvahiTimeout *t; + struct timeval now; + long long nowms, thenms, timeout; + VIR_DEBUG("Add timeout TV %p", tv); + if (VIR_ALLOC(t) < 0) { + virReportOOMError(); + return NULL; + } + + if (gettimeofday(&now, NULL) < 0) { + virReportSystemError(errno, "%s", + _("Unable to get current time")); + VIR_FREE(t); + return NULL; + } + + VIR_DEBUG("Trigger timed for %d %d %d %d", + (int)now.tv_sec, (int)now.tv_usec, + (int)(tv ? tv->tv_sec : 0), (int)(tv ? tv->tv_usec : 0)); + nowms = (now.tv_sec * 1000ll) + (now.tv_usec / 1000ll); + if (tv) { + thenms = (tv->tv_sec * 1000ll) + (tv->tv_usec/1000ll); + timeout = thenms > nowms ? nowms - thenms : 0; + if (timeout < 0) + timeout = 0; + } else { + timeout = -1; + } + + t->timer = virEventAddTimeout(timeout, + virNetServerMDNSTimeoutDispatch, + t, + virNetServerMDNSTimeoutDofree); + t->callback = cb; + t->userdata = userdata; + + if (t->timer < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Failed to add timer with timeout %d"), (int)timeout); + VIR_FREE(t); + return NULL; + } + + return t; +} + +static void virNetServerMDNSTimeoutUpdate(AvahiTimeout *t, const struct timeval *tv) +{ + struct timeval now; + long long nowms, thenms, timeout; + VIR_DEBUG("Update timeout %p TV %p", t, tv); + if (gettimeofday(&now, NULL) < 0) { + VIR_FREE(t); + return; + } + + nowms = (now.tv_sec * 1000ll) + (now.tv_usec / 1000ll); + if (tv) { + thenms = ((tv->tv_sec * 1000ll) + (tv->tv_usec/1000ll)); + timeout = thenms > nowms ? 
nowms - thenms : 0; + if (timeout < 0) + timeout = 0; + } else { + timeout = -1; + } + + virEventUpdateTimeout(t->timer, timeout); +} + +static void virNetServerMDNSTimeoutFree(AvahiTimeout *t) +{ + VIR_DEBUG("Free timeout %p", t); + virEventRemoveTimeout(t->timer); +} + + +static AvahiPoll *virNetServerMDNSCreatePoll(void) +{ + AvahiPoll *p; + if (VIR_ALLOC(p) < 0) { + virReportOOMError(); + return NULL; + } + + p->userdata = NULL; + + p->watch_new = virNetServerMDNSWatchNew; + p->watch_update = virNetServerMDNSWatchUpdate; + p->watch_get_events = virNetServerMDNSWatchGetEvents; + p->watch_free = virNetServerMDNSWatchFree; + + p->timeout_new = virNetServerMDNSTimeoutNew; + p->timeout_update = virNetServerMDNSTimeoutUpdate; + p->timeout_free = virNetServerMDNSTimeoutFree; + + return p; +} + + +virNetServerMDNS *virNetServerMDNSNew(void) +{ + virNetServerMDNS *mdns; + if (VIR_ALLOC(mdns) < 0) + return NULL; + + /* Allocate main loop object */ + if (!(mdns->poller = virNetServerMDNSCreatePoll())) { + VIR_FREE(mdns); + return NULL; + } + + return mdns; +} + + +int virNetServerMDNSStart(virNetServerMDNS *mdns) +{ + int error; + VIR_DEBUG("Starting client %p", mdns); + mdns->client = avahi_client_new(mdns->poller, + AVAHI_CLIENT_NO_FAIL, + virNetServerMDNSClientCallback, + mdns, &error); + + if (!mdns->client) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Failed to create mDNS client: %s"), + avahi_strerror(error)); + return -1; + } + + return 0; +} + + +virNetServerMDNSGroupPtr virNetServerMDNSAddGroup(virNetServerMDNS *mdns, + const char *name) +{ + virNetServerMDNSGroupPtr group; + + VIR_DEBUG("Adding group '%s'", name); + if (VIR_ALLOC(group) < 0) { + virReportOOMError(); + return NULL; + } + + if (!(group->name = strdup(name))) { + VIR_FREE(group); + virReportOOMError(); + return NULL; + } + group->mdns = mdns; + group->next = mdns->group; + mdns->group = group; + return group; +} + + +void virNetServerMDNSRemoveGroup(virNetServerMDNSPtr mdns, + 
virNetServerMDNSGroupPtr group) +{ + virNetServerMDNSGroupPtr tmp = mdns->group, prev = NULL; + + while (tmp) { + if (tmp == group) { + VIR_FREE(group->name); + if (prev) + prev->next = group->next; + else + group->mdns->group = group->next; + VIR_FREE(group); + return; + } + prev = tmp; + tmp = tmp->next; + } +} + + +virNetServerMDNSEntryPtr virNetServerMDNSAddEntry(virNetServerMDNSGroupPtr group, + const char *type, + int port) +{ + virNetServerMDNSEntryPtr entry; + + VIR_DEBUG("Adding entry %s %d to group %s", type, port, group->name); + if (VIR_ALLOC(entry) < 0) { + virReportOOMError(); + return NULL; + } + + entry->port = port; + if (!(entry->type = strdup(type))) { + VIR_FREE(entry); + virReportOOMError(); + return NULL; + } + entry->next = group->entry; + group->entry = entry; + return entry; +} + + +void virNetServerMDNSRemoveEntry(virNetServerMDNSGroupPtr group, + virNetServerMDNSEntryPtr entry) +{ + virNetServerMDNSEntryPtr tmp = group->entry, prev = NULL; + + while (tmp) { + if (tmp == entry) { + VIR_FREE(entry->type); + if (prev) + prev->next = entry->next; + else + group->entry = entry->next; + return; + } + prev = tmp; + tmp = tmp->next; + } +} + + +void virNetServerMDNSStop(virNetServerMDNSPtr mdns) +{ + virNetServerMDNSGroupPtr group = mdns->group; + while (group) { + if (group->handle) { + avahi_entry_group_free(group->handle); + group->handle = NULL; + } + group = group->next; + } + if (mdns->client) + avahi_client_free(mdns->client); + mdns->client = NULL; +} + + +void virNetServerMDNSFree(virNetServerMDNSPtr mdns) +{ + virNetServerMDNSGroupPtr group, tmp; + + if (!mdns) + return; + + group = mdns->group; + while (group) { + tmp = group->next; + virNetServerMDNSGroupFree(group); + group = tmp; + } + + VIR_FREE(mdns); +} + + +void virNetServerMDNSGroupFree(virNetServerMDNSGroupPtr grp) +{ + virNetServerMDNSEntryPtr entry, tmp; + + if (!grp) + return; + + entry = grp->entry; + while (entry) { + tmp = entry->next; + virNetServerMDNSEntryFree(entry); 
+ entry = tmp; + } + + VIR_FREE(grp); +} + + +void virNetServerMDNSEntryFree(virNetServerMDNSEntryPtr entry) +{ + if (!entry) + return; + + VIR_FREE(entry); +} + + diff --git a/src/rpc/virnetservermdns.h b/src/rpc/virnetservermdns.h new file mode 100644 index 0000000..9284f4a --- /dev/null +++ b/src/rpc/virnetservermdns.h @@ -0,0 +1,108 @@ +/* + * virnetservermdns.c: advertise server sockets + * + * Copyright (C) 2007 Daniel P. Berrange + * + * Derived from Avahi example service provider code. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_SERVER_MDNS_H__
+# define __VIR_NET_SERVER_MDNS_H__
+
+#include "internal.h"
+
+typedef struct _virNetServerMDNS virNetServerMDNS;
+typedef virNetServerMDNS *virNetServerMDNSPtr;
+typedef struct _virNetServerMDNSGroup virNetServerMDNSGroup;
+typedef virNetServerMDNSGroup *virNetServerMDNSGroupPtr;
+typedef struct _virNetServerMDNSEntry virNetServerMDNSEntry;
+typedef virNetServerMDNSEntry *virNetServerMDNSEntryPtr;
+
+
+/**
+ * Prepares a new mdns manager object for use
+ *
+ * returns the new manager, or NULL upon failure
+ */
+virNetServerMDNSPtr virNetServerMDNSNew(void);
+
+/**
+ * Starts the mdns client, advertising any groups/entries currently registered
+ *
+ * @mdns: manager to start advertising
+ *
+ * Starts the mdns client. Services may not be immediately visible, since
+ * it may asynchronously wait for the mdns service to start up
+ *
+ * returns -1 upon failure, 0 upon success.
+ */
+int virNetServerMDNSStart(virNetServerMDNSPtr mdns);
+
+/**
+ * Stops the mdns client, removing any advertisements
+ *
+ * @mdns: manager to stop advertising
+ *
+ */
+void virNetServerMDNSStop(virNetServerMDNSPtr mdns);
+
+/**
+ * Adds a group container for advertisement
+ *
+ * @mdns manager to attach the group to
+ * @name unique human readable service name
+ *
+ * returns the group record, or NULL upon failure
+ */
+virNetServerMDNSGroupPtr virNetServerMDNSAddGroup(virNetServerMDNSPtr mdns,
+                                                  const char *name);
+
+/**
+ * Removes a group container from advertisement
+ *
+ * @mdns manager to detach group from
+ * @group group to remove; it is released and must not be used afterwards
+ */
+void virNetServerMDNSRemoveGroup(virNetServerMDNSPtr mdns,
+                                 virNetServerMDNSGroupPtr group);
+
+/**
+ * Adds a service entry in a group
+ *
+ * @group group to attach the entry to
+ * @type service type string
+ * @port tcp port number
+ *
+ * returns the service record, or NULL upon failure
+ */
+virNetServerMDNSEntryPtr virNetServerMDNSAddEntry(virNetServerMDNSGroupPtr group,
+                                                  const char *type, int port);
+
+/**
+ * Removes a service entry from a group
+ *
+ * @group group to detach service entry from
+ * @entry service entry to remove; it must not be used afterwards
+ */
+void virNetServerMDNSRemoveEntry(virNetServerMDNSGroupPtr group,
+                                 virNetServerMDNSEntryPtr entry);
+
+/* Release a manager, group or entry; a NULL argument is ignored */
+void virNetServerMDNSFree(virNetServerMDNSPtr ptr);
+void virNetServerMDNSGroupFree(virNetServerMDNSGroupPtr ptr);
+void virNetServerMDNSEntryFree(virNetServerMDNSEntryPtr ptr);
+
+#endif /* __VIR_NET_SERVER_MDNS_H__ */
diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c
index 3f6cf4b..48c00b1 100644
--- a/src/rpc/virnetserverservice.c
+++ b/src/rpc/virnetserverservice.c
@@ -187,6 +187,14 @@ error:
 }
 
 
+/* Returns the local port of the service's first listening socket */
+int virNetServerServiceGetPort(virNetServerServicePtr svc)
+{
+    /* We're assuming if there are multiple sockets
+     * for IPv4 & 6, then they are all on same port */
+    /* NOTE(review): socks[0] is read unconditionally, so this assumes the
+     * service owns at least one socket -- confirm for every caller */
+    return virNetSocketGetPort(svc->socks[0]);
+}
+
+
 int virNetServerServiceGetAuth(virNetServerServicePtr svc)
 {
     return svc->auth;
diff --git a/src/rpc/virnetserverservice.h b/src/rpc/virnetserverservice.h
index b59c8fa..121e2f4 100644
--- a/src/rpc/virnetserverservice.h
+++ b/src/rpc/virnetserverservice.h
@@ -48,6 +48,8 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
                                                  bool readonly,
                                                  virNetTLSContextPtr tls);
 
+int virNetServerServiceGetPort(virNetServerServicePtr svc);
+
 int virNetServerServiceGetAuth(virNetServerServicePtr svc);
 bool virNetServerServiceIsReadonly(virNetServerServicePtr svc);
 
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index a5ee861..326e219 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -671,6 +671,12 @@ bool virNetSocketIsLocal(virNetSocketPtr sock)
 }
 
 
+/* Returns the port of the socket's local address */
+int virNetSocketGetPort(virNetSocketPtr sock)
+{
+    return virSocketGetPort(&sock->localAddr);
+}
+
+
 #ifdef SO_PEERCRED
 int virNetSocketGetLocalIdentity(virNetSocketPtr sock,
                                  uid_t *uid,
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 1be423b..0dcb553 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -77,6 +77,8 @@ int virNetSocketNewConnectExternal(const char **cmdargv,
 int virNetSocketGetFD(virNetSocketPtr sock);
 bool virNetSocketIsLocal(virNetSocketPtr sock);
 
+int virNetSocketGetPort(virNetSocketPtr sock);
+
 int virNetSocketGetLocalIdentity(virNetSocketPtr sock,
                                  uid_t *uid,
                                  pid_t *pid);
-- 
1.7.4

To facilitate creation of new clients using XDR RPC services, pull a lot of the remote driver code into a set of reusable objects. - virNetClient: Encapsulates a socket connection to a remote RPC server. Handles all the network I/O for reading/writing RPC messages. Delegates RPC encoding and decoding to the registered programs. - virNetClientProgram: Handles processing and dispatch of RPC messages for a single RPC (program, version). A program can register to receive async events from a client. - virNetClientStream: Handles generic I/O stream integration with the RPC layer. Each new client program now merely needs to define the list of RPC procedures & events it wants and their handlers. It does not need to deal with any of the network I/O functionality at all. --- po/POTFILES.in | 2 + src/Makefile.am | 14 +- src/rpc/virnetclient.c | 1147 +++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetclient.h | 86 +++ src/rpc/virnetclientprogram.c | 342 ++++++++++++ src/rpc/virnetclientprogram.h | 85 +++ src/rpc/virnetclientstream.c | 476 +++++++++++++++++ src/rpc/virnetclientstream.h | 76 +++ 8 files changed, 2227 insertions(+), 1 deletions(-) create mode 100644 src/rpc/virnetclient.c create mode 100644 src/rpc/virnetclient.h create mode 100644 src/rpc/virnetclientprogram.c create mode 100644 src/rpc/virnetclientprogram.h create mode 100644 src/rpc/virnetclientstream.c create mode 100644 src/rpc/virnetclientstream.h diff --git a/po/POTFILES.in b/po/POTFILES.in index c071874..135fbb8 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -64,6 +64,8 @@ src/qemu/qemu_monitor_json.c src/qemu/qemu_monitor_text.c src/qemu/qemu_process.c src/remote/remote_driver.c +src/rpc/virnetclient.c +src/rpc/virnetclientprogram.c src/rpc/virnetmessage.c src/rpc/virnetsaslcontext.c src/rpc/virnetsocket.c diff --git a/src/Makefile.am b/src/Makefile.am index 3a724d1..709a762 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1199,7 +1199,7 @@ libvirt_qemu_la_LIBADD = libvirt.la $(CYGWIN_EXTRA_LIBADD)
EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) -noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la +noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la libvirt-net-rpc-client.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ @@ -1249,6 +1249,18 @@ libvirt_net_rpc_server_la_LDFLAGS = \ libvirt_net_rpc_server_la_LIBADD = \ $(CYGWIN_EXTRA_LIBADD) +libvirt_net_rpc_client_la_SOURCES = \ + rpc/virnetclientprogram.h rpc/virnetclientprogram.c \ + rpc/virnetclientstream.h rpc/virnetclientstream.c \ + rpc/virnetclient.h rpc/virnetclient.c +libvirt_net_rpc_client_la_CFLAGS = \ + $(AM_CFLAGS) +libvirt_net_rpc_client_la_LDFLAGS = \ + $(AM_LDFLAGS) \ + $(CYGWIN_EXTRA_LDFLAGS) \ + $(MINGW_EXTRA_LDFLAGS) +libvirt_net_rpc_client_la_LIBADD = \ + $(CYGWIN_EXTRA_LIBADD) libexec_PROGRAMS = diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c new file mode 100644 index 0000000..c245cd6 --- /dev/null +++ b/src/rpc/virnetclient.c @@ -0,0 +1,1147 @@ +/* + * virnetclient.c: generic network RPC client + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P.
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <unistd.h> +#include <poll.h> +#include <signal.h> + +#include "virnetclient.h" +#include "virnetsocket.h" +#include "memory.h" +#include "threads.h" +#include "files.h" +#include "logging.h" +#include "util.h" +#include "virterror_internal.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) \ + virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +typedef struct _virNetClientCall virNetClientCall; +typedef virNetClientCall *virNetClientCallPtr; + +enum { + VIR_NET_CLIENT_MODE_WAIT_TX, + VIR_NET_CLIENT_MODE_WAIT_RX, + VIR_NET_CLIENT_MODE_COMPLETE, +}; + +struct _virNetClientCall { + int mode; + + virNetMessagePtr msg; + int expectReply; + + virCond cond; + + virNetClientCallPtr next; +}; + + +struct _virNetClient { + int refs; + + virMutex lock; + + virNetSocketPtr sock; + + virNetTLSSessionPtr tls; + char *hostname; + + virNetClientProgramPtr *programs; + size_t nprograms; + + /* For incoming message packets */ + virNetMessage msg; + +#if HAVE_SASL + virNetSASLSessionPtr sasl; +#endif + + /* Self-pipe to wakeup threads waiting in poll() */ + int wakeupSendFD; + int wakeupReadFD; + + /* List of threads currently waiting for dispatch */ + virNetClientCallPtr waitDispatch; + + size_t nstreams; + virNetClientStreamPtr *streams; +}; + + +static void virNetClientLock(virNetClientPtr client) +{ + virMutexLock(&client->lock); +} + + +static void virNetClientUnlock(virNetClientPtr client) +{ + virMutexUnlock(&client->lock); +} + + +static void virNetClientIncomingEvent(virNetSocketPtr sock, + int events, + void *opaque); + +static virNetClientPtr virNetClientNew(virNetSocketPtr sock, + const char *hostname) +{ + virNetClientPtr client; + int wakeupFD[2] = { -1, -1 }; + + if (pipe(wakeupFD) < 0) { + virReportSystemError(errno, "%s", + _("unable to make pipe")); + goto error; + } + + if (VIR_ALLOC(client) < 0) + goto no_memory; + + 
client->refs = 1; + + if (virMutexInit(&client->lock) < 0) + goto error; + + client->sock = sock; + client->wakeupReadFD = wakeupFD[0]; + client->wakeupSendFD = wakeupFD[1]; + wakeupFD[0] = wakeupFD[1] = -1; + + if (hostname && + !(client->hostname = strdup(hostname))) + goto no_memory; + + /* Set up a callback to listen on the socket data */ + if (virNetSocketAddIOCallback(client->sock, + VIR_EVENT_HANDLE_READABLE, + virNetClientIncomingEvent, + client) < 0) + VIR_DEBUG0("Failed to add event watch, disabling events"); + + return client; + +no_memory: + virReportOOMError(); +error: + VIR_FORCE_CLOSE(wakeupFD[0]); + VIR_FORCE_CLOSE(wakeupFD[1]); + virNetClientFree(client); + return NULL; +} + + +virNetClientPtr virNetClientNewUNIX(const char *path, + bool spawnDaemon, + const char *binary) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectUNIX(path, spawnDaemon, binary, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + + +virNetClientPtr virNetClientNewTCP(const char *nodename, + const char *service) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectTCP(nodename, service, &sock) < 0) + return NULL; + + return virNetClientNew(sock, nodename); +} + +virNetClientPtr virNetClientNewSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectSSH(nodename, service, binary, username, noTTY, netcat, path, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + +virNetClientPtr virNetClientNewExternal(const char **cmdargv) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectExternal(cmdargv, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + + +void virNetClientRef(virNetClientPtr client) +{ + virNetClientLock(client); + client->refs++; + virNetClientUnlock(client); +} + + +void virNetClientFree(virNetClientPtr client) +{ + int i; + + if (!client) + 
return; + + virNetClientLock(client); + client->refs--; + if (client->refs > 0) { + virNetClientUnlock(client); + return; + } + + for (i = 0 ; i < client->nprograms ; i++) + virNetClientProgramFree(client->programs[i]); + VIR_FREE(client->programs); + + VIR_FORCE_CLOSE(client->wakeupSendFD); + VIR_FORCE_CLOSE(client->wakeupReadFD); + + VIR_FREE(client->hostname); + + virNetSocketRemoveIOCallback(client->sock); + virNetSocketFree(client->sock); + virNetTLSSessionFree(client->tls); +#if HAVE_SASL + virNetSASLSessionFree(client->sasl); +#endif + virNetClientUnlock(client); + virMutexDestroy(&client->lock); + + VIR_FREE(client); +} + + +#if HAVE_SASL +void virNetClientSetSASLSession(virNetClientPtr client, + virNetSASLSessionPtr sasl) +{ + virNetClientLock(client); + client->sasl = sasl; + virNetSASLSessionRef(sasl); + virNetSocketSetSASLSession(client->sock, client->sasl); + virNetClientUnlock(client); +} +#endif + + +int virNetClientSetTLSSession(virNetClientPtr client, + virNetTLSContextPtr tls) +{ + int ret; + char buf[1]; + int len; + struct pollfd fds[1]; +#ifdef HAVE_PTHREAD_SIGMASK + sigset_t oldmask, blockedsigs; + + sigemptyset (&blockedsigs); + sigaddset (&blockedsigs, SIGWINCH); + sigaddset (&blockedsigs, SIGCHLD); + sigaddset (&blockedsigs, SIGPIPE); +#endif + + virNetClientLock(client); + + if (!(client->tls = virNetTLSSessionNew(tls, + client->hostname))) + goto error; + + virNetSocketSetTLSSession(client->sock, client->tls); + + for (;;) { + ret = virNetTLSSessionHandshake(client->tls); + + if (ret < 0) + goto error; + if (ret == 0) + break; + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[0].revents = 0; + if (virNetTLSSessionGetHandshakeStatus(client->tls) == + VIR_NET_TLS_HANDSHAKE_RECVING) + fds[0].events = POLLIN; + else + fds[0].events = POLLOUT; + + /* Block SIGWINCH from interrupting poll in curses programs, + * then restore the original signal mask again immediately + * after the call (RHBZ#567931). 
Same for SIGCHLD and SIGPIPE + * at the suggestion of Paolo Bonzini and Daniel Berrange. + */ +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll: + ret = poll(fds, ARRAY_CARDINALITY(fds), -1); + if (ret < 0 && errno == EAGAIN) + goto repoll; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &oldmask, NULL)); +#endif + } + + ret = virNetTLSContextCheckCertificate(tls, client->tls); + + if (ret < 0) + goto error; + + /* At this point, the server is verifying _our_ certificate, IP address, + * etc. If we make the grade, it will send us a '\1' byte. + */ + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[0].revents = 0; + fds[0].events = POLLIN; + +#ifdef HAVE_PTHREAD_SIGMASK + /* Block SIGWINCH from interrupting poll in curses programs */ + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll2: + ret = poll(fds, ARRAY_CARDINALITY(fds), -1); + if (ret < 0 && errno == EAGAIN) + goto repoll2; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &oldmask, NULL)); +#endif + + len = virNetTLSSessionRead(client->tls, buf, 1); + if (len < 0) { + virReportSystemError(errno, "%s", + _("Unable to read TLS confirmation")); + goto error; + } + if (len != 1 || buf[0] != '\1') { + virNetError(VIR_ERR_RPC, "%s", + _("server verification (of our certificate or IP " + "address) failed")); + goto error; + } + + virNetClientUnlock(client); + return 0; + +error: + virNetTLSSessionFree(client->tls); + client->tls = NULL; + virNetClientUnlock(client); + return -1; +} + +bool virNetClientIsEncrypted(virNetClientPtr client) +{ + bool ret = false; + virNetClientLock(client); + if (client->tls) + ret = true; +#if HAVE_SASL + if (client->sasl) + ret = true; +#endif + virNetClientUnlock(client); + return ret; +} + + +int virNetClientAddProgram(virNetClientPtr client, + virNetClientProgramPtr prog) +{ + virNetClientLock(client); + + if 
(VIR_EXPAND_N(client->programs, client->nprograms, 1) < 0) + goto no_memory; + + client->programs[client->nprograms-1] = prog; + virNetClientProgramRef(prog); + + virNetClientUnlock(client); + return 0; + +no_memory: + virReportOOMError(); + virNetClientUnlock(client); + return -1; +} + + +int virNetClientAddStream(virNetClientPtr client, + virNetClientStreamPtr st) +{ + virNetClientLock(client); + + if (VIR_EXPAND_N(client->streams, client->nstreams, 1) < 0) + goto no_memory; + + client->streams[client->nstreams-1] = st; + virNetClientStreamRef(st); + + virNetClientUnlock(client); + return 0; + +no_memory: + virReportOOMError(); + virNetClientUnlock(client); + return -1; +} + + +void virNetClientRemoveStream(virNetClientPtr client, + virNetClientStreamPtr st) +{ + virNetClientLock(client); + size_t i; + for (i = 0 ; i < client->nstreams ; i++) { + if (client->streams[i] == st) + break; + } + if (i == client->nstreams) + goto cleanup; + + if (client->nstreams > 1) { + memmove(client->streams + i, + client->streams + i + 1, + sizeof(*client->streams) * + (client->nstreams - (i + 1))); + VIR_SHRINK_N(client->streams, client->nstreams, 1); + } else { + VIR_FREE(client->streams); + client->nstreams = 0; + } + virNetClientStreamFree(st); + +cleanup: + virNetClientUnlock(client); +} + + +const char *virNetClientLocalAddrString(virNetClientPtr client) +{ + return virNetSocketLocalAddrString(client->sock); +} + +const char *virNetClientRemoteAddrString(virNetClientPtr client) +{ + return virNetSocketRemoteAddrString(client->sock); +} + +int virNetClientGetTLSKeySize(virNetClientPtr client) +{ + int ret = 0; + virNetClientLock(client); + if (client->tls) + ret = virNetTLSSessionGetKeySize(client->tls); + virNetClientUnlock(client); + return ret; +} + +static int +virNetClientCallDispatchReply(virNetClientPtr client) +{ + virNetClientCallPtr thecall; + + /* Ok, definitely got an RPC reply now find + out who's been waiting for it */ + thecall = client->waitDispatch; + while 
(thecall &&
+           !(thecall->msg->header.prog == client->msg.header.prog &&
+             thecall->msg->header.vers == client->msg.header.vers &&
+             thecall->msg->header.serial == client->msg.header.serial))
+        thecall = thecall->next;
+
+    if (!thecall) {
+        virNetError(VIR_ERR_RPC,
+                    _("no call waiting for reply with prog %d vers %d serial %d"),
+                    client->msg.header.prog, client->msg.header.vers, client->msg.header.serial);
+        return -1;
+    }
+
+    memcpy(thecall->msg->buffer, client->msg.buffer, sizeof(client->msg.buffer));
+    memcpy(&thecall->msg->header, &client->msg.header, sizeof(client->msg.header));
+    thecall->msg->bufferLength = client->msg.bufferLength;
+    thecall->msg->bufferOffset = client->msg.bufferOffset;
+
+    thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
+
+    return 0;
+}
+
+/*
+ * Hand an asynchronous VIR_NET_MESSAGE packet (e.g. an event) to the
+ * first registered program whose program/version match its header.
+ * The result of virNetClientProgramDispatch() is not propagated.
+ *
+ * Returns 0 if a matching program was found, -1 otherwise.
+ */
+static int virNetClientCallDispatchMessage(virNetClientPtr client)
+{
+    size_t i;
+    virNetClientProgramPtr prog = NULL;
+
+    for (i = 0 ; i < client->nprograms ; i++) {
+        if (virNetClientProgramMatches(client->programs[i],
+                                       &client->msg)) {
+            prog = client->programs[i];
+            break;
+        }
+    }
+    if (!prog) {
+        VIR_DEBUG("No program found for event with prog=%d vers=%d",
+                  client->msg.header.prog, client->msg.header.vers);
+        return -1;
+    }
+
+    virNetClientProgramDispatch(prog, client, &client->msg);
+
+    return 0;
+}
+
+/*
+ * Handle a VIR_NET_STREAM packet: find the stream it is addressed to,
+ * then queue its data (CONTINUE), record its error (ERROR) or treat it
+ * as a confirmation (OK). A synchronous call waiting on the same
+ * prog/vers/serial is marked complete where the status allows it.
+ *
+ * Returns 0 on success, -1 if no stream matches, queueing/recording
+ * the packet fails, an OK arrives with no synchronous caller, or the
+ * status is unknown.
+ */
+static int virNetClientCallDispatchStream(virNetClientPtr client)
+{
+    size_t i;
+    virNetClientStreamPtr st = NULL;
+    virNetClientCallPtr thecall;
+
+    /* First identify what stream this packet is directed at */
+    for (i = 0 ; i < client->nstreams ; i++) {
+        if (virNetClientStreamMatches(client->streams[i],
+                                      &client->msg)) {
+            st = client->streams[i];
+            break;
+        }
+    }
+    if (!st) {
+        VIR_DEBUG("No stream found for packet with prog=%d vers=%d serial=%u proc=%u",
+                  client->msg.header.prog, client->msg.header.vers,
+                  client->msg.header.serial, client->msg.header.proc);
+        return -1;
+    }
+
+    /* Finish/Abort are synchronous, so also see if there's an
+     * (optional) call waiting for this stream packet */
+    thecall = client->waitDispatch;
+    while (thecall &&
+           !(thecall->msg->header.prog == client->msg.header.prog &&
+             thecall->msg->header.vers == client->msg.header.vers &&
+             thecall->msg->header.serial == client->msg.header.serial))
+        thecall = thecall->next;
+
+    VIR_DEBUG("Found call %p", thecall);
+
+    /* Status is either
+     *   - REMOTE_OK - no payload for streams
+     *   - REMOTE_ERROR - followed by a remote_error struct
+     *   - REMOTE_CONTINUE - followed by a raw data packet
+     */
+    switch (client->msg.header.status) {
+    case VIR_NET_CONTINUE: {
+        if (virNetClientStreamQueuePacket(st, &client->msg) < 0)
+            return -1;
+
+        if (thecall && thecall->expectReply) {
+            VIR_DEBUG0("Got sync data packet completion");
+            thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
+        } else {
+            /* XXX remoteStreamEventTimerUpdate(privst); */
+        }
+        return 0;
+    }
+
+    case VIR_NET_OK:
+        if (thecall && thecall->expectReply) {
+            VIR_DEBUG0("Got a synchronous confirm");
+            thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
+        } else {
+            VIR_DEBUG0("Got unexpected async stream finish confirmation");
+            return -1;
+        }
+        return 0;
+
+    case VIR_NET_ERROR:
+        /* Always record the error against the stream; a synchronous
+         * caller additionally gets it raised immediately below */
+        if (virNetClientStreamSetError(st, &client->msg) < 0)
+            return -1;
+
+        if (thecall && thecall->expectReply) {
+            VIR_DEBUG0("Got a synchronous error");
+            /* Raise error now, so that this call will see it immediately */
+            virNetClientStreamRaiseError(st);
+            thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
+        }
+        return 0;
+
+    default:
+        VIR_WARN("Stream with unexpected serial=%d, proc=%d, status=%d",
+                 client->msg.header.serial, client->msg.header.proc,
+                 client->msg.header.status);
+        return -1;
+    }
+}
+
+
+/*
+ * Decode the header of the complete packet held in client->msg and
+ * route it by message type to the reply, async message or stream
+ * handler.
+ *
+ * Returns the handler's result, or -1 if the header cannot be decoded
+ * or the message type is not recognised.
+ */
+static int
+virNetClientCallDispatch(virNetClientPtr client)
+{
+    if (virNetMessageDecodeHeader(&client->msg) < 0)
+        return -1;
+
+    switch (client->msg.header.type) {
+    case VIR_NET_REPLY: /* Normal RPC replies */
+        return virNetClientCallDispatchReply(client);
+
+    case VIR_NET_MESSAGE: /* Async notifications */
+        return virNetClientCallDispatchMessage(client);
+
+    case VIR_NET_STREAM: /* Stream protocol */
+        return virNetClientCallDispatchStream(client);
+
+    default:
+        virNetError(VIR_ERR_RPC,
+                    _("got unexpected RPC call prog %d vers %d proc %d type %d"),
+                    client->msg.header.prog, client->msg.header.vers,
+                    client->msg.header.proc, client->msg.header.type);
+        return -1;
+    }
+}
+
+
+/*
+ * Write as much of thecall's pending message buffer as the socket
+ * accepts. Once the whole buffer has gone out, the buffer offsets are
+ * reset and the call moves to WAIT_RX (reply expected) or COMPLETE.
+ *
+ * Returns the number of bytes written, or the (<= 0) result of
+ * virNetSocketWrite() unchanged.
+ */
+static ssize_t
+virNetClientIOWriteMessage(virNetClientPtr client,
+                           virNetClientCallPtr thecall)
+{
+    ssize_t ret;
+
+    ret = virNetSocketWrite(client->sock,
+                            thecall->msg->buffer + thecall->msg->bufferOffset,
+                            thecall->msg->bufferLength - thecall->msg->bufferOffset);
+    if (ret <= 0)
+        return ret;
+
+    thecall->msg->bufferOffset += ret;
+
+    if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
+        thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
+        if (thecall->expectReply)
+            thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
+        else
+            thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
+    }
+
+    return ret;
+}
+
+
+/*
+ * Called when the socket is writable: transmit the queued calls that
+ * are in WAIT_TX mode, in queue order, until the socket would block or
+ * nothing is left to send. Calls in any other mode are skipped rather
+ * than handed to virNetSocketWrite() (their buffers are either empty
+ * or hold an already received reply).
+ *
+ * Returns 0 on success, -1 if no call was waiting to transmit or a
+ * write failed.
+ */
+static ssize_t
+virNetClientIOHandleOutput(virNetClientPtr client)
+{
+    virNetClientCallPtr thecall = client->waitDispatch;
+
+    while (thecall &&
+           thecall->mode != VIR_NET_CLIENT_MODE_WAIT_TX)
+        thecall = thecall->next;
+
+    if (!thecall)
+        return -1; /* Shouldn't happen, but you never know... */
+
+    for (; thecall ; thecall = thecall->next) {
+        ssize_t ret;
+
+        if (thecall->mode != VIR_NET_CLIENT_MODE_WAIT_TX)
+            continue;
+
+        ret = virNetClientIOWriteMessage(client, thecall);
+        if (ret < 0)
+            return ret;
+
+        if (thecall->mode == VIR_NET_CLIENT_MODE_WAIT_TX)
+            return 0; /* Blocking write, go back to event loop */
+    }
+
+    return 0; /* No more calls to send, all done */
+}
+
+/*
+ * Read the next chunk of the incoming packet into client->msg. When no
+ * packet is in progress only the 4-byte length word is requested, so
+ * the full packet size can be decoded before the body is read.
+ *
+ * Returns the number of bytes read, or the (<= 0) result of
+ * virNetSocketRead() unchanged.
+ */
+static ssize_t
+virNetClientIOReadMessage(virNetClientPtr client)
+{
+    size_t wantData;
+    ssize_t ret;
+
+    /* Start by reading length word */
+    if (client->msg.bufferLength == 0)
+        client->msg.bufferLength = 4;
+
+    wantData = client->msg.bufferLength - client->msg.bufferOffset;
+
+    ret = virNetSocketRead(client->sock,
+                           client->msg.buffer + client->msg.bufferOffset,
+                           wantData);
+    if (ret <= 0)
+        return ret;
+
+    client->msg.bufferOffset += ret;
+
+    return ret;
+}
+
+
+/*
+ * Called when the socket is readable: keep reading until the socket
+ * would block or one complete packet has been read and dispatched.
+ *
+ * Returns -1 on a read or length-decode error, 0 when blocked on read,
+ * otherwise the result of virNetClientCallDispatch() for the packet.
+ */
+static ssize_t
+virNetClientIOHandleInput(virNetClientPtr client)
+{
+    /* Read as much data as is available, until we get
+     * EAGAIN
+     */
+    for (;;) {
+        ssize_t ret = virNetClientIOReadMessage(client);
+
+        if (ret < 0)
+            return -1;
+        if (ret == 0)
+            return 0;  /* Blocking on read */
+
+        /* Check for completion of our goal */
+        if (client->msg.bufferOffset == client->msg.bufferLength) {
+            if (client->msg.bufferOffset == 4) {
+                ret = virNetMessageDecodeLength(&client->msg);
+                if (ret < 0)
+                    return -1;
+
+                /*
+                 * We'll carry on around the loop to immediately
+                 * process the message body, because it has probably
+                 * already arrived. Worst case, we'll get EAGAIN on
+                 * next iteration.
+                 */
+            } else {
+                ret = virNetClientCallDispatch(client);
+                client->msg.bufferOffset = client->msg.bufferLength = 0;
+                /*
+                 * We've completed one call, so return even
+                 * though there might still be more data on
+                 * the wire. We need to actually let the caller
+                 * deal with this arrived message to keep good
+                 * response, and also to correctly handle EOF.
+                 */
+                return ret;
+            }
+        }
+    }
+}
+
+
+/*
+ * Process all calls pending dispatch/receive until we
+ * get a reply to our own call.
Then quit and pass the buck
+ * to someone else.
+ *
+ * The code relies on 'thiscall' being at the head of
+ * client->waitDispatch and on the client lock being held on entry;
+ * the lock is dropped only around the blocking poll().
+ *
+ * Returns 0 once 'thiscall' is complete, -1 on poll/I/O failure. In
+ * both cases 'thiscall' is removed from the queue and the next queued
+ * call, if any, is woken so it can take over dispatching.
+ */
+static int virNetClientIOEventLoop(virNetClientPtr client,
+                                   virNetClientCallPtr thiscall)
+{
+    struct pollfd fds[2];
+    int ret;
+
+    fds[0].fd = virNetSocketGetFD(client->sock);
+    fds[1].fd = client->wakeupReadFD;
+
+    for (;;) {
+        virNetClientCallPtr tmp = client->waitDispatch;
+        virNetClientCallPtr prev;
+        char ignore;
+#ifdef HAVE_PTHREAD_SIGMASK
+        sigset_t oldmask, blockedsigs;
+#endif
+        int timeout = -1;
+        int saved_errno;
+
+        /* If we have existing SASL decoded data we
+         * don't want to sleep in the poll(), just
+         * check if any other FDs are also ready
+         */
+        if (virNetSocketHasCachedData(client->sock))
+            timeout = 0;
+
+        fds[0].events = fds[0].revents = 0;
+        fds[1].events = fds[1].revents = 0;
+
+        fds[1].events = POLLIN;
+        while (tmp) {
+            if (tmp->mode == VIR_NET_CLIENT_MODE_WAIT_RX)
+                fds[0].events |= POLLIN;
+            if (tmp->mode == VIR_NET_CLIENT_MODE_WAIT_TX)
+                fds[0].events |= POLLOUT;
+
+            tmp = tmp->next;
+        }
+
+        /* We have to be prepared to receive stream data
+         * regardless of whether any of the calls waiting
+         * for dispatch are for streams.
+         */
+        if (client->nstreams)
+            fds[0].events |= POLLIN;
+
+        /* Release lock while poll'ing so other threads
+         * can stuff themselves on the queue */
+        virNetClientUnlock(client);
+
+        /* Block SIGWINCH from interrupting poll in curses programs,
+         * then restore the original signal mask again immediately
+         * after the call (RHBZ#567931).  Same for SIGCHLD and SIGPIPE
+         * at the suggestion of Paolo Bonzini and Daniel Berrange.
+         */
+#ifdef HAVE_PTHREAD_SIGMASK
+        sigemptyset (&blockedsigs);
+        sigaddset (&blockedsigs, SIGWINCH);
+        sigaddset (&blockedsigs, SIGCHLD);
+        sigaddset (&blockedsigs, SIGPIPE);
+        ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask));
+#endif
+
+    repoll:
+        ret = poll(fds, ARRAY_CARDINALITY(fds), timeout);
+        /* Signals outside the blocked set can still interrupt poll();
+         * retry on EINTR as well as EAGAIN instead of failing the call */
+        if (ret < 0 && (errno == EAGAIN || errno == EINTR))
+            goto repoll;
+        /* Save errno now: the sigmask/lock/read calls below may
+         * overwrite it before a poll() failure is reported */
+        saved_errno = errno;
+
+#ifdef HAVE_PTHREAD_SIGMASK
+        ignore_value(pthread_sigmask(SIG_SETMASK, &oldmask, NULL));
+#endif
+
+        virNetClientLock(client);
+
+        /* If we have existing SASL decoded data, pretend
+         * the socket became readable so we consume it
+         */
+        if (virNetSocketHasCachedData(client->sock))
+            fds[0].revents |= POLLIN;
+
+        if (fds[1].revents) {
+            VIR_DEBUG0("Woken up from poll by other thread");
+            if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) {
+                virReportSystemError(errno, "%s",
+                                     _("read on wakeup fd failed"));
+                goto error;
+            }
+        }
+
+        if (ret < 0) {
+            if (saved_errno == EWOULDBLOCK)
+                continue;
+            virReportSystemError(saved_errno,
+                                 "%s", _("poll on socket failed"));
+            goto error;
+        }
+
+        if (fds[0].revents & POLLOUT) {
+            if (virNetClientIOHandleOutput(client) < 0)
+                goto error;
+        }
+
+        if (fds[0].revents & POLLIN) {
+            if (virNetClientIOHandleInput(client) < 0)
+                goto error;
+        }
+
+        /* Iterate through waiting threads and if
+         * any are complete then tell 'em to wakeup
+         */
+        tmp = client->waitDispatch;
+        prev = NULL;
+        while (tmp) {
+            if (tmp != thiscall &&
+                tmp->mode == VIR_NET_CLIENT_MODE_COMPLETE) {
+                /* Take them out of the list */
+                if (prev)
+                    prev->next = tmp->next;
+                else
+                    client->waitDispatch = tmp->next;
+
+                /* And wake them up....
+                 * ...they won't actually wakeup until
+                 * we release our mutex a short while
+                 * later...
+                 */
+                VIR_DEBUG("Waking up sleep %p %p", tmp, client->waitDispatch);
+                virCondSignal(&tmp->cond);
+            } else {
+                /* Only calls that stay queued may become 'prev';
+                 * otherwise the next unlink would patch a node that
+                 * is no longer in the list */
+                prev = tmp;
+            }
+            tmp = tmp->next;
+        }
+
+        /* Now see if *we* are done */
+        if (thiscall->mode == VIR_NET_CLIENT_MODE_COMPLETE) {
+            /* We're at head of the list already, so
+             * remove us
+             */
+            client->waitDispatch = thiscall->next;
+            VIR_DEBUG("Giving up the buck %p %p", thiscall, client->waitDispatch);
+            /* See if someone else is still waiting
+             * and if so, then pass the buck ! */
+            if (client->waitDispatch) {
+                VIR_DEBUG("Passing the buck to %p", client->waitDispatch);
+                virCondSignal(&client->waitDispatch->cond);
+            }
+            return 0;
+        }
+
+
+        if (fds[0].revents & (POLLHUP | POLLERR)) {
+            virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                        _("received hangup / error event on socket"));
+            goto error;
+        }
+    }
+
+
+error:
+    client->waitDispatch = thiscall->next;
+    VIR_DEBUG("Giving up the buck due to I/O error %p %p", thiscall, client->waitDispatch);
+    /* See if someone else is still waiting
+     * and if so, then pass the buck ! */
+    if (client->waitDispatch) {
+        VIR_DEBUG("Passing the buck to %p", client->waitDispatch);
+        virCondSignal(&client->waitDispatch->cond);
+    }
+    return -1;
+}
+
+
+/*
+ * This function sends the message held by 'thiscall' and, if the call
+ * expects a reply, waits for the call to complete (RPC replies are
+ * copied into thiscall->msg by the dispatching thread).
+ *
+ * NB. 'thiscall' is owned by the caller; it is linked into
+ * client->waitDispatch while this runs and must stay valid until this
+ * function returns.
+ *
+ * NB(2). Returns 0 once 'thiscall' is complete, -1 on failure.
+ *
+ * NB(3) You must have the client lock before calling this
+ *
+ * NB(4) This is very complicated. Multiple threads are allowed to
+ * use the client for RPC at the same time. Obviously only one of
+ * them can. So if someone's using the socket, other threads are put
+ * to sleep on condition variables.
The existing thread may completely
+ * send & receive their RPC call/reply while they're asleep. Or it
+ * may only get around to dealing with sending the call. Or it may
+ * get around to neither. So upon waking up from slumber, the other
+ * thread may or may not have more work to do.
+ *
+ * We call this dance  'passing the buck'
+ *
+ *      http://en.wikipedia.org/wiki/Passing_the_buck
+ *
+ *   "Buck passing or passing the buck is the action of transferring
+ *    responsibility or blame unto another person. It is also used as
+ *    a strategy in power politics when the actions of one country/
+ *    nation are blamed on another, providing an opportunity for war."
+ *
+ * NB(5) Don't Panic!
+ */
+static int virNetClientIO(virNetClientPtr client,
+                          virNetClientCallPtr thiscall)
+{
+    int rv = -1;
+
+    VIR_DEBUG("program=%u version=%u serial=%u proc=%d type=%d length=%zu dispatch=%p",
+              thiscall->msg->header.prog,
+              thiscall->msg->header.vers,
+              thiscall->msg->header.serial,
+              thiscall->msg->header.proc,
+              thiscall->msg->header.type,
+              thiscall->msg->bufferLength,
+              client->waitDispatch);
+
+    /* Check to see if another thread is dispatching */
+    if (client->waitDispatch) {
+        /* Stick ourselves on the end of the wait queue */
+        virNetClientCallPtr tmp = client->waitDispatch;
+        char ignore = 1;
+        while (tmp && tmp->next)
+            tmp = tmp->next;
+        if (tmp)
+            tmp->next = thiscall;
+        else
+            client->waitDispatch = thiscall;
+
+        /* Force other thread to wakeup from poll */
+        if (safewrite(client->wakeupSendFD, &ignore, sizeof(ignore)) != sizeof(ignore)) {
+            if (tmp)
+                tmp->next = NULL;
+            else
+                client->waitDispatch = NULL;
+            virReportSystemError(errno, "%s",
+                                 _("failed to wake up polling thread"));
+            return -1;
+        }
+
+        VIR_DEBUG("Going to sleep %p %p", client->waitDispatch, thiscall);
+        /* Go to sleep while other thread is working... Condition
+         * variables can wake up spuriously, so keep waiting until
+         * either our call has been completed for us, or the other
+         * thread has made us head of the queue (passed the buck).
+         * Proceeding on any other wakeup would make two threads
+         * dispatch on the socket at once.
+         */
+        while (thiscall->mode != VIR_NET_CLIENT_MODE_COMPLETE &&
+               client->waitDispatch != thiscall) {
+            if (virCondWait(&thiscall->cond, &client->lock) < 0) {
+                if (client->waitDispatch == thiscall) {
+                    client->waitDispatch = thiscall->next;
+                } else {
+                    tmp = client->waitDispatch;
+                    while (tmp && tmp->next &&
+                           tmp->next != thiscall) {
+                        tmp = tmp->next;
+                    }
+                    if (tmp && tmp->next == thiscall)
+                        tmp->next = thiscall->next;
+                }
+                virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                            _("failed to wait on condition"));
+                return -1;
+            }
+        }
+
+        VIR_DEBUG("Wokeup from sleep %p %p", client->waitDispatch, thiscall);
+        /* Two reasons we can be woken up
+         *  1. Other thread has got our reply ready for us
+         *  2. Other thread is all done, and it is our turn to
+         *     be the dispatcher to finish waiting for
+         *     our reply
+         */
+        if (thiscall->mode == VIR_NET_CLIENT_MODE_COMPLETE) {
+            rv = 0;
+            /*
+             * We avoided catching the buck and our reply is ready !
+             * We've already had 'thiscall' removed from the list
+             * so just need to (maybe) handle errors & free it
+             */
+            goto cleanup;
+        }
+
+        /* Grr, someone passed the buck onto us ... */
+
+    } else {
+        /* We're first to catch the buck */
+        client->waitDispatch = thiscall;
+    }
+
+    VIR_DEBUG("We have the buck %p %p", client->waitDispatch, thiscall);
+    /*
+     * The buck stops here!
+     *
+     * At this point we're about to own the dispatch
+     * process...
+     */
+
+    /*
+     * Avoid needless wake-ups of the event loop in the
+     * case where this call is being made from a different
+     * thread than the event loop. These wake-ups would
+     * cause the event loop thread to be blocked on the
+     * mutex for the duration of the call
+     */
+    virNetSocketUpdateIOCallback(client->sock, 0);
+
+    rv = virNetClientIOEventLoop(client, thiscall);
+
+    virNetSocketUpdateIOCallback(client->sock, VIR_EVENT_HANDLE_READABLE);
+
+cleanup:
+    VIR_DEBUG("All done with our call %p %p %d", client->waitDispatch, thiscall, rv);
+    return rv;
+}
+
+
+/*
+ * Socket I/O callback for data arriving while no call is being
+ * dispatched (it does nothing if a thread is currently dispatching).
+ * On hangup or error the socket's I/O callback is removed; otherwise
+ * incoming data is processed with virNetClientIOHandleInput().
+ */
+void virNetClientIncomingEvent(virNetSocketPtr sock,
+                               int events,
+                               void *opaque)
+{
+    virNetClientPtr client = opaque;
+
+    virNetClientLock(client);
+
+    /* This should be impossible, but it doesn't hurt to check */
+    if (client->waitDispatch)
+        goto done;
+
+    VIR_DEBUG("Event fired %p %d", sock, events);
+
+    if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) {
+        VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or "
+                  "VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__);
+        virNetSocketRemoveIOCallback(sock);
+        goto done;
+    }
+
+    if (virNetClientIOHandleInput(client) < 0)
+        VIR_DEBUG0("Something went wrong during async message processing");
+
+done:
+    virNetClientUnlock(client);
+}
+
+
+/*
+ * Send 'msg' and, if 'expectReply' is true, wait for the matching
+ * reply, which is copied back into 'msg'. A message with an empty
+ * buffer is not transmitted; the call only waits to receive.
+ *
+ * Returns 0 on success, -1 on failure.
+ */
+int virNetClientSend(virNetClientPtr client,
+                     virNetMessagePtr msg,
+                     bool expectReply)
+{
+    virNetClientCallPtr call;
+    int ret = -1;
+
+    if (VIR_ALLOC(call) < 0) {
+        virReportOOMError();
+        return -1;
+    }
+
+    virNetClientLock(client);
+
+    if (virCondInit(&call->cond) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("cannot initialize condition variable"));
+        goto cleanup;
+    }
+
+    if (msg->bufferLength)
+        call->mode = VIR_NET_CLIENT_MODE_WAIT_TX;
+    else
+        call->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
+    call->msg = msg;
+    call->expectReply = expectReply;
+
+    ret = virNetClientIO(client, call);
+
+    /* virNetClientIO has unlinked the call on every return path, so
+     * nobody can signal its condition any more: release it */
+    ignore_value(virCondDestroy(&call->cond));
+
+cleanup:
+    VIR_FREE(call);
+    virNetClientUnlock(client);
+    return ret;
+}
diff --git a/src/rpc/virnetclient.h b/src/rpc/virnetclient.h
new file mode 100644
index 0000000..8029c08
--- /dev/null
+++ b/src/rpc/virnetclient.h
@@ -0,0 +1,86 @@
+/*
+ * virnetclient.h: generic network RPC
client
+ *
+ * Copyright (C) 2006-2010 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_CLIENT_H__
+# define __VIR_NET_CLIENT_H__
+
+# include <stdbool.h>
+
+# include "virnettlscontext.h"
+# include "virnetmessage.h"
+# ifdef HAVE_SASL
+#  include "virnetsaslcontext.h"
+# endif
+# include "virnetclientprogram.h"
+# include "virnetclientstream.h"
+
+
+virNetClientPtr virNetClientNewUNIX(const char *path,
+                                    bool spawnDaemon,
+                                    const char *daemon);
+
+virNetClientPtr virNetClientNewTCP(const char *nodename,
+                                   const char *service);
+
+virNetClientPtr virNetClientNewSSH(const char *nodename,
+                                   const char *service,
+                                   const char *binary,
+                                   const char *username,
+                                   bool noTTY,
+                                   const char *netcat,
+                                   const char *path);
+
+virNetClientPtr virNetClientNewExternal(const char **cmdargv);
+
+void virNetClientRef(virNetClientPtr client);
+
+int virNetClientAddProgram(virNetClientPtr client,
+                           virNetClientProgramPtr prog);
+
+int virNetClientAddStream(virNetClientPtr client,
+                          virNetClientStreamPtr st);
+
+void virNetClientRemoveStream(virNetClientPtr client,
+                              virNetClientStreamPtr st);
+
+int virNetClientSend(virNetClientPtr client,
+                     virNetMessagePtr msg,
+                     bool expectReply);
+
+# ifdef HAVE_SASL
+void virNetClientSetSASLSession(virNetClientPtr client,
+                                virNetSASLSessionPtr sasl);
+# endif
+
+int virNetClientSetTLSSession(virNetClientPtr client,
+                              virNetTLSContextPtr tls);
+
+bool virNetClientIsEncrypted(virNetClientPtr client);
+
+const char *virNetClientLocalAddrString(virNetClientPtr client);
+const char *virNetClientRemoteAddrString(virNetClientPtr client);
+
+int virNetClientGetTLSKeySize(virNetClientPtr client);
+
+void virNetClientFree(virNetClientPtr client);
+
+#endif /* __VIR_NET_CLIENT_H__ */
diff --git a/src/rpc/virnetclientprogram.c b/src/rpc/virnetclientprogram.c
new file mode 100644
index 0000000..12a4cb0
--- /dev/null
+++ b/src/rpc/virnetclientprogram.c
@@ -0,0 +1,387 @@
+/*
+ * virnetclientprogram.c: generic network RPC client program
+ *
+ * Copyright (C) 2006-2010 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#include "virnetclientprogram.h"
+#include "virnetclient.h"
+#include "virnetprotocol.h"
+
+#include "memory.h"
+#include "virterror_internal.h"
+#include "logging.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+#define virNetError(code, ...)                                    \
+    virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__,           \
+                         __FUNCTION__, __LINE__, __VA_ARGS__)
+
+struct _virNetClientProgram {
+    int refs;
+
+    unsigned program;
+    unsigned version;
+    virNetClientProgramEventPtr events;
+    size_t nevents;
+    void *eventOpaque;
+};
+
+/*
+ * Create a handle for RPC program number 'program', version 'version'.
+ * The 'events' array ('nevents' entries) is stored by reference, not
+ * copied, so it must remain valid for as long as the handle is used.
+ * The handle starts with a reference count of 1.
+ *
+ * Returns the new handle, or NULL on allocation failure.
+ */
+virNetClientProgramPtr virNetClientProgramNew(unsigned program,
+                                              unsigned version,
+                                              virNetClientProgramEventPtr events,
+                                              size_t nevents,
+                                              void *eventOpaque)
+{
+    virNetClientProgramPtr prog;
+
+    if (VIR_ALLOC(prog) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    prog->refs = 1;
+    prog->program = program;
+    prog->version = version;
+    prog->events = events;
+    prog->nevents = nevents;
+    prog->eventOpaque = eventOpaque;
+
+    return prog;
+}
+
+
+/* Take an additional reference on 'prog'. */
+void virNetClientProgramRef(virNetClientProgramPtr prog)
+{
+    prog->refs++;
+}
+
+
+/* Release a reference on 'prog', freeing it once the count reaches
+ * zero. Passing NULL is a no-op. */
+void virNetClientProgramFree(virNetClientProgramPtr prog)
+{
+    if (!prog)
+        return;
+
+    prog->refs--;
+    if (prog->refs > 0)
+        return;
+
+    VIR_FREE(prog);
+}
+
+
+/* Returns the RPC program number of 'prog'. */
+unsigned virNetClientProgramGetProgram(virNetClientProgramPtr prog)
+{
+    return prog->program;
+}
+
+
+/* Returns the RPC program version of 'prog'. */
+unsigned virNetClientProgramGetVersion(virNetClientProgramPtr prog)
+{
+    return prog->version;
+}
+
+
+/* Returns 1 if the program and version in msg's header match 'prog',
+ * 0 otherwise. */
+int virNetClientProgramMatches(virNetClientProgramPtr prog,
+                               virNetMessagePtr msg)
+{
+    if (prog->program == msg->header.prog &&
+        prog->version == msg->header.vers)
+        return 1;
+    return 0;
+}
+
+
+/*
+ * Decode the virNetMessageError payload of 'msg' and raise it as the
+ * current libvirt error. Codes sent by 0.7.1 - 0.7.7 servers are
+ * remapped, and "unknown procedure" RPC errors become VIR_ERR_NO_SUPPORT.
+ *
+ * Returns 0 once the error has been raised, -1 if the payload could
+ * not be decoded.
+ */
+static int
+virNetClientProgramDispatchError(virNetClientProgramPtr prog ATTRIBUTE_UNUSED,
+                                 virNetMessagePtr msg)
+{
+    virNetMessageError err;
+    int ret = -1;
+
+    memset(&err, 0, sizeof(err));
+
+    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
+        goto cleanup;
+
+    /* Interop for virErrorNumber glitch in 0.8.0, if server is
+     * 0.7.1 through 0.7.7; see comments in virterror.h. */
+    switch (err.code) {
+    case VIR_WAR_NO_NWFILTER:
+        /* no way to tell old VIR_WAR_NO_SECRET apart from
+         * VIR_WAR_NO_NWFILTER, but both are very similar
+         * warnings, so ignore the difference */
+        break;
+    case VIR_ERR_INVALID_NWFILTER:
+    case VIR_ERR_NO_NWFILTER:
+    case VIR_ERR_BUILD_FIREWALL:
+        /* server was trying to pass VIR_ERR_INVALID_SECRET,
+         * VIR_ERR_NO_SECRET, or VIR_ERR_CONFIG_UNSUPPORTED */
+        if (err.domain != VIR_FROM_NWFILTER)
+            err.code += 4;
+        break;
+    case VIR_WAR_NO_SECRET:
+        if (err.domain == VIR_FROM_QEMU)
+            err.code = VIR_ERR_OPERATION_TIMEOUT;
+        break;
+    case VIR_ERR_INVALID_SECRET:
+        if (err.domain == VIR_FROM_XEN)
+            err.code = VIR_ERR_MIGRATE_PERSIST_FAILED;
+        break;
+    default:
+        /* Nothing to alter. */
+        break;
+    }
+
+    if (err.domain == VIR_FROM_REMOTE &&
+        err.code == VIR_ERR_RPC &&
+        err.level == VIR_ERR_ERROR &&
+        err.message &&
+        STRPREFIX(*err.message, "unknown procedure")) {
+        virRaiseErrorFull(NULL,
+                          __FILE__, __FUNCTION__, __LINE__,
+                          err.domain,
+                          VIR_ERR_NO_SUPPORT,
+                          err.level,
+                          err.str1 ? *err.str1 : NULL,
+                          err.str2 ? *err.str2 : NULL,
+                          err.str3 ? *err.str3 : NULL,
+                          err.int1,
+                          err.int2,
+                          "%s", *err.message);
+    } else {
+        virRaiseErrorFull(NULL,
+                          __FILE__, __FUNCTION__, __LINE__,
+                          err.domain,
+                          err.code,
+                          err.level,
+                          err.str1 ? *err.str1 : NULL,
+                          err.str2 ? *err.str2 : NULL,
+                          err.str3 ? *err.str3 : NULL,
+                          err.int1,
+                          err.int2,
+                          "%s", err.message ? *err.message : _("Unknown error"));
+    }
+
+    ret = 0;
+
+cleanup:
+    xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
+    return ret;
+}
+
+
+/* Returns the event entry registered for 'procedure', or NULL. */
+static virNetClientProgramEventPtr virNetClientProgramGetEvent(virNetClientProgramPtr prog,
+                                                               int procedure)
+{
+    size_t i;
+
+    for (i = 0 ; i < prog->nevents ; i++) {
+        if (prog->events[i].proc == procedure)
+            return &prog->events[i];
+    }
+
+    return NULL;
+}
+
+
+/*
+ * Dispatch an asynchronous event message to the handler registered for
+ * its procedure. The payload is decoded into a freshly allocated buffer
+ * of event->msg_len bytes, which is passed to the handler and then
+ * released along with any memory XDR allocated inside it.
+ *
+ * Returns 0 if the handler was invoked, -1 if the message does not
+ * match this program, has an unexpected status, type or procedure, or
+ * its payload could not be allocated or decoded.
+ */
+int virNetClientProgramDispatch(virNetClientProgramPtr prog,
+                                virNetClientPtr client,
+                                virNetMessagePtr msg)
+{
+    virNetClientProgramEventPtr event;
+    char *evdata;
+    int ret = -1;
+
+    VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%d proc=%d",
+              msg->header.prog, msg->header.vers, msg->header.type,
+              msg->header.status, msg->header.serial, msg->header.proc);
+
+    /* Check version, etc. */
+    if (msg->header.prog != prog->program) {
+        VIR_ERROR(_("program mismatch in event (actual %x, expected %x)"),
+                  msg->header.prog, prog->program);
+        return -1;
+    }
+
+    if (msg->header.vers != prog->version) {
+        VIR_ERROR(_("version mismatch in event (actual %x, expected %x)"),
+                  msg->header.vers, prog->version);
+        return -1;
+    }
+
+    if (msg->header.status != VIR_NET_OK) {
+        VIR_ERROR(_("status mismatch in event (actual %x, expected %x)"),
+                  msg->header.status, VIR_NET_OK);
+        return -1;
+    }
+
+    if (msg->header.type != VIR_NET_MESSAGE) {
+        VIR_ERROR(_("type mismatch in event (actual %x, expected %x)"),
+                  msg->header.type, VIR_NET_MESSAGE);
+        return -1;
+    }
+
+    event = virNetClientProgramGetEvent(prog, msg->header.proc);
+
+    if (!event) {
+        VIR_ERROR(_("No event expected with procedure %x"),
+                  msg->header.proc);
+        return -1;
+    }
+
+    if (VIR_ALLOC_N(evdata, event->msg_len) < 0) {
+        virReportOOMError();
+        return -1;
+    }
+
+    if (virNetMessageDecodePayload(msg, event->msg_filter, evdata) < 0)
+        goto cleanup;
+
+    /* Pass the decoded struct itself, not &evdata */
+    event->func(prog, client, evdata, prog->eventOpaque);
+
+    xdr_free(event->msg_filter, evdata);
+    ret = 0;
+
+cleanup:
+    VIR_FREE(evdata);
+    return ret;
+}
+
+
+/*
+ * Perform a synchronous RPC: encode procedure 'proc' with the given
+ * 'serial' and arguments, send it, wait for the reply and decode its
+ * payload into 'ret' with 'ret_filter'. A VIR_NET_ERROR reply is
+ * raised as the current libvirt error.
+ *
+ * Returns 0 on success, -1 on failure.
+ */
+int virNetClientProgramCall(virNetClientProgramPtr prog,
+                            virNetClientPtr client,
+                            unsigned serial,
+                            int proc,
+                            xdrproc_t args_filter, void *args,
+                            xdrproc_t ret_filter, void *ret)
+{
+    virNetMessagePtr msg;
+
+    if (!(msg = virNetMessageNew()))
+        return -1;
+
+    msg->header.prog = prog->program;
+    msg->header.vers = prog->version;
+    msg->header.status = VIR_NET_OK;
+    msg->header.type = VIR_NET_CALL;
+    msg->header.serial = serial;
+    msg->header.proc = proc;
+
+    if (virNetMessageEncodeHeader(msg) < 0)
+        goto error;
+
+    if (virNetMessageEncodePayload(msg, args_filter, args) < 0)
+        goto error;
+
+    if (virNetClientSend(client, msg, true) < 0)
+        goto error;
+
+    /* None of these 3 should ever happen here, because
+     * virNetClientSend should have validated the reply,
+     * but it doesn't hurt to check again.
+     */
+    if (msg->header.type != VIR_NET_REPLY) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    _("Unexpected message type %d"), msg->header.type);
+        goto error;
+    }
+    if (msg->header.proc != proc) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    _("Unexpected message proc %d != %d"),
+                    msg->header.proc, proc);
+        goto error;
+    }
+    if (msg->header.serial != serial) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    _("Unexpected message serial %u != %u"),
+                    msg->header.serial, serial);
+        goto error;
+    }
+
+    switch (msg->header.status) {
+    case VIR_NET_OK:
+        if (virNetMessageDecodePayload(msg, ret_filter, ret) < 0)
+            goto error;
+        break;
+
+    case VIR_NET_ERROR:
+        virNetClientProgramDispatchError(prog, msg);
+        goto error;
+
+    default:
+        virNetError(VIR_ERR_RPC,
+                    _("Unexpected message status %d"), msg->header.status);
+        goto error;
+    }
+
+    VIR_FREE(msg);
+
+    return 0;
+
+error:
+    VIR_FREE(msg);
+    return -1;
+}
diff --git a/src/rpc/virnetclientprogram.h b/src/rpc/virnetclientprogram.h
new file mode 100644
index 0000000..50474c5
--- /dev/null
+++ b/src/rpc/virnetclientprogram.h
@@ -0,0 +1,85 @@
+/*
+ * virnetclientprogram.h: generic network RPC client program
+ *
+ * Copyright (C) 2006-2010 Red Hat,
Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_CLIENT_PROGRAM_H__
+# define __VIR_NET_CLIENT_PROGRAM_H__
+
+# include <rpc/types.h>
+# include <rpc/xdr.h>
+
+# include "virnetmessage.h"
+
+typedef struct _virNetClient virNetClient;
+typedef virNetClient *virNetClientPtr;
+
+typedef struct _virNetClientProgram virNetClientProgram;
+typedef virNetClientProgram *virNetClientProgramPtr;
+
+typedef struct _virNetClientProgramEvent virNetClientProgramEvent;
+typedef virNetClientProgramEvent *virNetClientProgramEventPtr;
+
+typedef struct _virNetClientProgramErrorHandler virNetClientProgramErrorHander;
+typedef virNetClientProgramErrorHander *virNetClientProgramErrorHanderPtr;
+
+
+typedef void (*virNetClientProgramDispatchFunc)(virNetClientProgramPtr prog,
+                                                virNetClientPtr client,
+                                                void *msg,
+                                                void *opaque);
+
+struct _virNetClientProgramEvent {
+    int proc;
+    virNetClientProgramDispatchFunc func;
+    size_t msg_len;
+    xdrproc_t msg_filter;
+};
+
+virNetClientProgramPtr virNetClientProgramNew(unsigned program,
+                                              unsigned version,
+                                              virNetClientProgramEventPtr events,
+                                              size_t nevents,
+                                              void *eventOpaque);
+
+unsigned virNetClientProgramGetProgram(virNetClientProgramPtr prog);
+unsigned virNetClientProgramGetVersion(virNetClientProgramPtr prog);
+
+void virNetClientProgramRef(virNetClientProgramPtr prog);
+
+void virNetClientProgramFree(virNetClientProgramPtr prog);
+
+int virNetClientProgramMatches(virNetClientProgramPtr prog,
+                               virNetMessagePtr msg);
+
+int virNetClientProgramDispatch(virNetClientProgramPtr prog,
+                                virNetClientPtr client,
+                                virNetMessagePtr msg);
+
+int virNetClientProgramCall(virNetClientProgramPtr prog,
+                            virNetClientPtr client,
+                            unsigned serial,
+                            int proc,
+                            xdrproc_t args_filter, void *args,
+                            xdrproc_t ret_filter, void *ret);
+
+
+
+#endif /* __VIR_NET_CLIENT_PROGRAM_H__ */
diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c
new file mode 100644
index 0000000..0fbcf1a
--- /dev/null
+++ b/src/rpc/virnetclientstream.c
@@ -0,0 +1,540 @@
+/*
+ * virnetclientstream.c: generic network RPC client stream
+ *
+ * Copyright (C) 2006-2010 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#include "virnetclientstream.h"
+#include "virnetclient.h"
+#include "memory.h"
+#include "virterror_internal.h"
+#include "logging.h"
+#include "event.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+#define virNetError(code, ...)                                    \
+    virReportErrorHelper(NULL, VIR_FROM_RPC, code, __FILE__,           \
+                         __FUNCTION__, __LINE__, __VA_ARGS__)
+
+struct _virNetClientStream {
+    virNetClientProgramPtr prog;
+    int proc;
+    unsigned serial;
+    int refs;
+
+    virError err;
+
+    /* XXX this buffer is unbounded if the client
+     * app has domain events registered, since packets
+     * may be read off wire, while app isn't ready to
+     * recv them. Figure out how to address this some
+     * time by stopping consuming any incoming data
+     * off the socket....
+     */
+    char *incoming;
+    size_t incomingOffset;   /* bytes of queued data in 'incoming' */
+    size_t incomingLength;   /* allocated size of 'incoming' */
+
+
+    virNetClientStreamEventCallback cb;
+    void *cbOpaque;
+    virFreeCallback cbFree;
+    int cbEvents;
+    int cbTimer;
+    int cbDispatch;          /* set to 1 while 'cb' is being invoked */
+};
+
+
+/*
+ * Arm the stream's event timer to fire immediately if the registered
+ * callback wants READABLE and data is queued, or wants WRITABLE;
+ * otherwise disable the timer. Does nothing if no callback is set.
+ */
+static void
+virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
+{
+    if (!st->cb)
+        return;
+
+    VIR_DEBUG("Check timer offset=%zu %d", st->incomingOffset, st->cbEvents);
+
+    if ((st->incomingOffset &&
+         (st->cbEvents & VIR_STREAM_EVENT_READABLE)) ||
+        (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) {
+        VIR_DEBUG0("Enabling event timer");
+        virEventUpdateTimeout(st->cbTimer, 0);
+    } else {
+        VIR_DEBUG0("Disabling event timer");
+        virEventUpdateTimeout(st->cbTimer, -1);
+    }
+}
+
+
+/*
+ * Event timer handler: work out which of the requested events
+ * currently apply and invoke the callback with them. If the callback
+ * unregistered itself while running, its free function is then
+ * called on the saved opaque data.
+ */
+static void
+virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
+{
+    virNetClientStreamPtr st = opaque;
+    int events = 0;
+
+    /* XXX we need a mutex on 'st' to protect this callback */
+
+    if (st->cb &&
+        (st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
+        st->incomingOffset)
+        events |= VIR_STREAM_EVENT_READABLE;
+    if (st->cb &&
+        (st->cbEvents & VIR_STREAM_EVENT_WRITABLE))
+        events |= VIR_STREAM_EVENT_WRITABLE;
+
+    VIR_DEBUG("Got Timer dispatch %d %d offset=%zu", events, st->cbEvents, st->incomingOffset);
+    if (events) {
+        virNetClientStreamEventCallback cb = st->cb;
+        void *cbOpaque = st->cbOpaque;
+        virFreeCallback cbFree = st->cbFree;
+
+        st->cbDispatch = 1;
+        (cb)(st, events, cbOpaque);
+        st->cbDispatch = 0;
+
+        if (!st->cb && cbFree)
+            (cbFree)(cbOpaque);
+    }
+}
+
+
+/* Free function for the event timer: drops one reference on 'st'. */
+static void
+virNetClientStreamEventTimerFree(void *opaque)
+{
+    virNetClientStreamPtr st = opaque;
+    virNetClientStreamFree(st);
+}
+
+
+/*
+ * Create a stream bound to procedure 'proc' and call serial 'serial'
+ * of program 'prog', taking a reference on 'prog'. The stream starts
+ * with a reference count of 1.
+ *
+ * Returns the new stream, or NULL on allocation failure.
+ */
+virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
+                                            int proc,
+                                            unsigned serial)
+{
+    virNetClientStreamPtr st;
+
+    if (VIR_ALLOC(st) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    virNetClientProgramRef(prog);
+
+    st->refs = 1;
+    st->prog = prog;
+    st->proc = proc;
+    st->serial = serial;
+
+    return st;
+}
+
+
+/* Take an additional reference on 'st'. */
+void virNetClientStreamRef(virNetClientStreamPtr st)
+{
+    st->refs++;
+}
+
+/* Release a reference on 'st'; dropping the last one frees the saved
+ * error, any queued data and the program reference. NULL is a no-op. */
+void virNetClientStreamFree(virNetClientStreamPtr st)
+{
+    if (!st)
+        return;
+
+    st->refs--;
+    if (st->refs > 0)
+        return;
+
+    virResetError(&st->err);
+    VIR_FREE(st->incoming);
+    virNetClientProgramFree(st->prog);
+    VIR_FREE(st);
+}
+
+/* Returns true if 'msg' belongs to this stream, i.e. it matches the
+ * stream's program/version, procedure and serial. */
+bool virNetClientStreamMatches(virNetClientStreamPtr st,
+                               virNetMessagePtr msg)
+{
+    if (virNetClientProgramMatches(st->prog, msg) &&
+        st->proc == msg->header.proc &&
+        st->serial == msg->header.serial)
+        return true;
+    return false;
+}
+
+
+/*
+ * Raise the error recorded against the stream, if any, as the current
+ * libvirt error.
+ *
+ * Returns true if an error was raised, false if none is recorded.
+ */
+bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
+{
+    if (st->err.code == VIR_ERR_OK)
+        return false;
+
+    virRaiseErrorFull(NULL,
+                      __FILE__, __FUNCTION__, __LINE__,
+                      st->err.domain,
+                      st->err.code,
+                      st->err.level,
+                      st->err.str1,
+                      st->err.str2,
+                      st->err.str3,
+                      st->err.int1,
+                      st->err.int2,
+                      "%s", st->err.message ? st->err.message : _("Unknown error"));
+
+    return true;
+}
+
+
+/*
+ * Decode the virNetMessageError payload of 'msg' and record it against
+ * the stream, replacing any previously recorded error. "unknown
+ * procedure" RPC errors are recorded as VIR_ERR_NO_SUPPORT.
+ *
+ * Returns 0 on success, -1 if the payload could not be decoded; the
+ * previously recorded error is cleared in either case.
+ */
+int virNetClientStreamSetError(virNetClientStreamPtr st,
+                               virNetMessagePtr msg)
+{
+    virNetMessageError err;
+    int ret = -1;
+
+    if (st->err.code != VIR_ERR_OK)
+        VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));
+
+    virResetError(&st->err);
+    memset(&err, 0, sizeof(err));
+
+    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
+        goto cleanup;
+
+    if (err.domain == VIR_FROM_REMOTE &&
+        err.code == VIR_ERR_RPC &&
+        err.level == VIR_ERR_ERROR &&
+        err.message &&
+        STRPREFIX(*err.message, "unknown procedure")) {
+        st->err.code = VIR_ERR_NO_SUPPORT;
+    } else {
+        st->err.code = err.code;
+    }
+    st->err.domain = err.domain;
+    st->err.level = err.level;
+    st->err.int1 = err.int1;
+    st->err.int2 = err.int2;
+
+    /* Each string field is an optional, possibly NULL, pointer. Take
+     * ownership of the decoded strings and clear them in 'err' so the
+     * xdr_free() below does not release memory st->err now owns. */
+    if (err.message) {
+        st->err.message = *err.message;
+        *err.message = NULL;
+    }
+    if (err.str1) {
+        st->err.str1 = *err.str1;
+        *err.str1 = NULL;
+    }
+    if (err.str2) {
+        st->err.str2 = *err.str2;
+        *err.str2 = NULL;
+    }
+    if (err.str3) {
+        st->err.str3 = *err.str3;
+        *err.str3 = NULL;
+    }
+
+    ret = 0;
+
+cleanup:
+    xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
+    return ret;
+}
+
+
+/*
+ * Append the unread payload of 'msg' (bufferOffset up to bufferLength)
+ * to the stream's incoming data buffer, growing the buffer when the
+ * free space left in it is too small.
+ *
+ * Returns 0 on success, -1 on allocation failure (with an OOM error
+ * reported).
+ */
+int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
+                                  virNetMessagePtr msg)
+{
+    size_t avail = st->incomingLength - st->incomingOffset;
+    size_t need = msg->bufferLength - msg->bufferOffset;
+
+    if (need > avail) {
+        size_t extra = need - avail;
+        if (VIR_REALLOC_N(st->incoming,
+                          st->incomingLength + extra) < 0) {
+            virReportOOMError();
+            VIR_DEBUG0("Out of memory handling stream data");
+            return -1;
+        }
+        st->incomingLength += extra;
+    }
+
+    memcpy(st->incoming + st->incomingOffset,
+           msg->buffer + msg->bufferOffset,
+           msg->bufferLength - msg->bufferOffset);
+    st->incomingOffset += (msg->bufferLength - msg->bufferOffset);
+
+    VIR_DEBUG("Stream incoming data offset %zu length %zu",
+              st->incomingOffset, st->incomingLength);
+    return 0;
+}
+
+
+int virNetClientStreamSendPacket(virNetClientStreamPtr st,
+                                 virNetClientPtr client,
+                                 int status,
+                                 const char *data,
+                                 size_t nbytes)
+{
+    virNetMessagePtr msg;
+    bool wantReply;
+ VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes); + + if (!(msg = virNetMessageNew())) + return -1; + + msg->header.prog = virNetClientProgramGetProgram(st->prog); + msg->header.vers = virNetClientProgramGetVersion(st->prog); + msg->header.status = status; + msg->header.type = VIR_NET_STREAM; + msg->header.serial = st->serial; + msg->header.proc = st->proc; + + if (virNetMessageEncodeHeader(msg) < 0) + goto error; + + /* Data packets are async fire&forget, but OK/ERROR packets + * need a synchronous confirmation + */ + if (status == VIR_NET_CONTINUE) { + if (virNetMessageEncodePayloadRaw(msg, data, nbytes) < 0) + goto error; + wantReply = false; + } else { + wantReply = true; + } + + if (virNetClientSend(client, msg, wantReply) < 0) + goto error; + +#if 0 + if (wantReply) { + /* None of these 3 should ever happen here, because + * virNetClientSend should have validated the reply, + * but it doesn't hurt to check again. + */ + if (msg->header.type != VIR_NET_STREAM) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected message type %d"), msg->header.type); + goto error; + } + if (msg->header.proc != st->proc) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected message proc %d != %d"), + msg->header.proc, st->proc); + goto error; + } + if (msg->header.serial != st->serial) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected message serial %d != %d"), + msg->header.serial, st->serial); + goto error; + } + + switch (msg->header.status) { + case VIR_NET_OK: + /* No payload we need to care about */ + break; + + case VIR_NET_ERROR: + virNetClientProgramDispatchError(st->prog, msg); + goto error; + + default: + virNetError(VIR_ERR_RPC, + _("Unexpected message status %d"), msg->header.status); + goto error; + } + } +#endif + + return 0; + +error: + VIR_FREE(msg); + return -1; +} + +int virNetClientStreamRecvPacket(virNetClientStreamPtr st, + virNetClientPtr client, + char *data, + size_t nbytes, + bool nonblock) +{ + int rv = -1; + if 
(!st->incomingOffset) { + virNetMessagePtr msg; + int ret; + + if (nonblock) { + VIR_DEBUG0("Non-blocking mode and no data available"); + rv = -2; + goto cleanup; + } + + if (!(msg = virNetMessageNew())) { + virReportOOMError(); + goto cleanup; + } + + msg->header.serial = st->serial; + msg->header.proc = st->proc; + + ret = virNetClientSend(client, msg, true); + + virNetMessageFree(msg); + + if (ret < 0) + goto cleanup; + } + + VIR_DEBUG("After IO %zu", st->incomingOffset); + if (st->incomingOffset) { + int want = st->incomingOffset; + if (want > nbytes) + want = nbytes; + memcpy(data, st->incoming, want); + if (want < st->incomingOffset) { + memmove(st->incoming, st->incoming + want, st->incomingOffset - want); + st->incomingOffset -= want; + } else { + VIR_FREE(st->incoming); + st->incomingOffset = st->incomingLength = 0; + } + rv = want; + } else { + rv = 0; + } + + virNetClientStreamEventTimerUpdate(st); + +cleanup: + return rv; +} + + +int virNetClientStreamEventAddCallback(virNetClientStreamPtr st, + int events, + virNetClientStreamEventCallback cb, + void *opaque, + virFreeCallback ff) +{ + if (st->cb) { + virNetError(VIR_ERR_INTERNAL_ERROR, + "%s", _("multiple stream callbacks not supported")); + return 1; + } + + virNetClientStreamRef(st); + if ((st->cbTimer = + virEventAddTimeout(-1, + virNetClientStreamEventTimer, + st, + virNetClientStreamEventTimerFree)) < 0) { + virNetClientStreamFree(st); + return -1; + } + + st->cb = cb; + st->cbOpaque = opaque; + st->cbFree = ff; + st->cbEvents = events; + + virNetClientStreamEventTimerUpdate(st); + + return 0; +} + +int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st, + int events) +{ + if (!st->cb) { + virNetError(VIR_ERR_INTERNAL_ERROR, + "%s", _("no stream callback registered")); + return -1; + } + + st->cbEvents = events; + + virNetClientStreamEventTimerUpdate(st); + + return 0; +} + +int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st) +{ + if (!st->cb) { + 
virNetError(VIR_ERR_INTERNAL_ERROR, + "%s", _("no stream callback registered")); + return -1; + } + + if (!st->cbDispatch && + st->cbFree) + (st->cbFree)(st->cbOpaque); + st->cb = NULL; + st->cbOpaque = NULL; + st->cbFree = NULL; + st->cbEvents = 0; + virEventRemoveTimeout(st->cbTimer); + + return 0; +} diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h new file mode 100644 index 0000000..d36846c --- /dev/null +++ b/src/rpc/virnetclientstream.h @@ -0,0 +1,76 @@ +/* + * virnetclientstream.h: generic network RPC client stream + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_CLIENT_STREAM_H__ +# define __VIR_NET_CLIENT_STREAM_H__ + +# include "virnetclientprogram.h" + +typedef struct _virNetClientStream virNetClientStream; +typedef virNetClientStream *virNetClientStreamPtr; + +typedef void (*virNetClientStreamEventCallback)(virNetClientStreamPtr stream, + int events, void *opaque); + +virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog, + int proc, + unsigned serial); + +void virNetClientStreamRef(virNetClientStreamPtr st); + +void virNetClientStreamFree(virNetClientStreamPtr st); + +bool virNetClientStreamRaiseError(virNetClientStreamPtr st); + +int virNetClientStreamSetError(virNetClientStreamPtr st, + virNetMessagePtr msg); + +bool virNetClientStreamMatches(virNetClientStreamPtr st, + virNetMessagePtr msg); + +int virNetClientStreamQueuePacket(virNetClientStreamPtr st, + virNetMessagePtr msg); + +int virNetClientStreamSendPacket(virNetClientStreamPtr st, + virNetClientPtr client, + int status, + const char *data, + size_t nbytes); + +int virNetClientStreamRecvPacket(virNetClientStreamPtr st, + virNetClientPtr client, + char *data, + size_t nbytes, + bool nonblock); + +int virNetClientStreamEventAddCallback(virNetClientStreamPtr st, + int events, + virNetClientStreamEventCallback cb, + void *opaque, + virFreeCallback ff); + +int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st, + int events); +int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st); + + +#endif /* __VIR_NET_CLIENT_STREAM_H__ */ -- 1.7.4

Start of a simple test case for the socket APIs. So far it only covers basic server setup & client connect for TCP (IPv4/IPv6) and UNIX sockets, plus connecting via a spawned command and via an SSH tunnel * configure.ac: Check for ifaddrs.h * tests/.gitignore: Ignore the new test binaries * tests/Makefile.am: Add socket test * tests/virnetsockettest.c: New test case * tests/testutils.c: Avoid overriding LIBVIRT_DEBUG settings * tests/ssh.c: Fake ssh helper program for SSH tunnelling tests --- configure.ac | 2 +- tests/.gitignore | 2 + tests/Makefile.am | 14 ++- tests/ssh.c | 54 +++++ tests/testutils.c | 8 +- tests/virnetsockettest.c | 529 ++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 604 insertions(+), 5 deletions(-) create mode 100644 tests/ssh.c create mode 100644 tests/virnetsockettest.c diff --git a/configure.ac b/configure.ac index 81bad91..737a476 100644 --- a/configure.ac +++ b/configure.ac @@ -134,7 +134,7 @@ LIBS=$old_libs dnl Availability of various common headers (non-fatal if missing). AC_CHECK_HEADERS([pwd.h paths.h regex.h sys/syslimits.h sys/un.h \ sys/poll.h syslog.h mntent.h net/ethernet.h linux/magic.h \ - sys/un.h sys/syscall.h netinet/tcp.h fnmatch.h]) + sys/un.h sys/syscall.h netinet/tcp.h fnmatch.h ifaddrs.h]) AC_CHECK_LIB([intl],[gettext],[]) diff --git a/tests/.gitignore b/tests/.gitignore index e3906f0..e272cf6 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1,6 +1,7 @@ *.exe .deps .libs +ssh commandhelper commandhelper.log commandhelper.pid @@ -30,6 +31,7 @@ statstest storagepoolxml2xmltest storagevolxml2xmltest virbuftest +virnetsockettest virshtest vmx2xmltest xencapstest diff --git a/tests/Makefile.am b/tests/Makefile.am index 5896442..abf8f30 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -77,7 +77,12 @@ EXTRA_DIST = \ check_PROGRAMS = virshtest conftest sockettest \ nodeinfotest qparamtest virbuftest \ - commandtest commandhelper seclabeltest + commandtest commandhelper seclabeltest \ + virnetsockettest ssh + +# This is a fake SSH we use from virnetsockettest +ssh_SOURCES = ssh.c +ssh_LDADD = $(COVERAGE_LDFLAGS) if WITH_XEN check_PROGRAMS += xml2sexprtest
sexpr2xmltest \ @@ -159,6 +164,7 @@ TESTS = virshtest \ sockettest \ commandtest \ seclabeltest \ + virnetsockettest \ $(test_scripts) if WITH_XEN @@ -361,6 +367,12 @@ commandhelper_SOURCES = \ commandhelper_CFLAGS = -Dabs_builddir="\"`pwd`\"" commandhelper_LDADD = $(LDADDS) +virnetsockettest_SOURCES = \ + virnetsockettest.c testutils.h testutils.c +virnetsockettest_CFLAGS = -Dabs_builddir="\"$(abs_builddir)\"" +virnetsockettest_LDADD = $(LDADDS) + + seclabeltest_SOURCES = \ seclabeltest.c seclabeltest_LDADD = ../src/libvirt_driver_security.la $(LDADDS) diff --git a/tests/ssh.c b/tests/ssh.c new file mode 100644 index 0000000..ceaa29f --- /dev/null +++ b/tests/ssh.c @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2006-2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <stdio.h> +#include "internal.h" /* STREQ */ + +int main(int argc, char **argv) /* Fake ssh(1) binary used by virnetsockettest */ +{ + int i; + int failConnect = 0; /* "nosuchhost" arg: exit -1, no data on stdout, msg on stderr */ + int dieEarly = 0; /* "crashinghost" arg: exit -1 after partial data on stdout, msg on stderr */ + + for (i = 1 ; i < argc ; i++) { + if (STREQ(argv[i], "nosuchhost")) + failConnect = 1; + else if (STREQ(argv[i], "crashinghost")) + dieEarly = 1; + } + + if (failConnect) { + fprintf(stderr, "%s", "Cannot connect to host nosuchhost\n"); + return -1; + } + + if (dieEarly) { + printf("%s\n", "Hello World"); + fprintf(stderr, "%s", "Hangup from host\n"); + return -1; + } + + for (i = 1 ; i < argc ; i++) /* Otherwise echo all args, space separated, newline terminated */ + printf("%s%c", argv[i], i == (argc -1) ? '\n' : ' '); + + return 0; +} diff --git a/tests/testutils.c b/tests/testutils.c index 3110457..9b3cf59 100644 --- a/tests/testutils.c +++ b/tests/testutils.c @@ -495,9 +495,11 @@ int virtTestMain(int argc, return 1; virLogSetFromEnv(); - if (virLogDefineOutput(virtTestLogOutput, virtTestLogClose, &testLog, - 0, 0, NULL, 0) < 0) - return 1; + if (!getenv("LIBVIRT_DEBUG") && !virLogGetNbOutputs()) { + if (virLogDefineOutput(virtTestLogOutput, virtTestLogClose, &testLog, + 0, 0, NULL, 0) < 0) + return 1; + } #if TEST_OOM if ((oomStr = getenv("VIR_TEST_OOM")) != NULL) { diff --git a/tests/virnetsockettest.c b/tests/virnetsockettest.c new file mode 100644 index 0000000..39d496a --- /dev/null +++ b/tests/virnetsockettest.c @@ -0,0 +1,529 @@ +/* + * Copyright (C) 2006-2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#include <stdlib.h>
+#include <signal.h>
+#ifdef HAVE_IFADDRS_H
+# include <ifaddrs.h>
+#endif
+
+#include "testutils.h"
+#include "util.h"
+#include "virterror_internal.h"
+#include "memory.h"
+#include "logging.h"
+#include "ignore-value.h"
+#include "files.h"
+
+#include "rpc/virnetsocket.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+static char *argv0;
+static char cwd[PATH_MAX];
+
+#ifdef HAVE_IFADDRS_H
+# define BASE_PORT 5672 /* first TCP port probed by checkProtocols */
+
+static int
+checkProtocols(bool *hasIPv4, bool *hasIPv6,
+               int *freePort)
+{
+    struct ifaddrs *ifaddr = NULL, *ifa;
+    struct sockaddr_in in4;
+    struct sockaddr_in6 in6;
+    int s4 = -1, s6 = -1;
+    int i;
+    int ret = -1;
+
+    *hasIPv4 = *hasIPv6 = false;
+    *freePort = 0;
+    /* Record whether any interface carries an IPv4 / IPv6 address */
+    if (getifaddrs(&ifaddr) < 0)
+        goto cleanup;
+
+    for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
+        if (!ifa->ifa_addr)
+            continue;
+
+        if (ifa->ifa_addr->sa_family == AF_INET)
+            *hasIPv4 = true;
+        if (ifa->ifa_addr->sa_family == AF_INET6)
+            *hasIPv6 = true;
+    }
+
+    VIR_DEBUG("Protocols: v4 %d v6 %d\n", *hasIPv4, *hasIPv6);
+
+    freeifaddrs(ifaddr);
+    /* Find a port in [BASE_PORT, BASE_PORT + 50) free on 127.0.0.1 and, if present, ::1 */
+    for (i = 0 ; i < 50 ; i++) {
+        int only = 1;
+        if ((s4 = socket(AF_INET, SOCK_STREAM, 0)) < 0)
+            goto cleanup;
+        /* Only probe IPv6 when an IPv6 address exists, so IPv4-only hosts still work */
+        if (*hasIPv6 && (s6 = socket(AF_INET6, SOCK_STREAM, 0)) < 0)
+            goto cleanup;
+
+        if (*hasIPv6 && setsockopt(s6, IPPROTO_IPV6, IPV6_V6ONLY, &only, sizeof(only)) < 0)
+            goto cleanup;
+
+        memset(&in4, 0, sizeof(in4));
+        memset(&in6, 0, sizeof(in6));
+
+        in4.sin_family = AF_INET;
+        in4.sin_port = htons(BASE_PORT + i);
+        in4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+        in6.sin6_family = AF_INET6;
+        in6.sin6_port = htons(BASE_PORT + i);
+        in6.sin6_addr = in6addr_loopback;
+
+        if (bind(s4, (struct sockaddr *)&in4, sizeof(in4)) < 0) {
+            if (errno == EADDRINUSE) {
+                VIR_FORCE_CLOSE(s4);
+                VIR_FORCE_CLOSE(s6);
+                continue;
+            }
+            goto cleanup;
+        }
+        if (*hasIPv6 && bind(s6, (struct sockaddr *)&in6, sizeof(in6)) < 0) {
+            if (errno == EADDRINUSE) {
+                VIR_FORCE_CLOSE(s4);
+                VIR_FORCE_CLOSE(s6);
+                continue;
+            }
+            goto cleanup;
+        }
+
+        *freePort = BASE_PORT + i;
+        break;
+    }
+
+    VIR_DEBUG("Choose port %d\n", *freePort);
+
+    ret = 0;
+
+cleanup:
+    VIR_FORCE_CLOSE(s4);
+    VIR_FORCE_CLOSE(s6);
+    return ret;
+}
+
+/* Parameters for testSocketTCPAccept: listen node, port, connect node */
+struct testTCPData {
+    const char *lnode;
+    int port;
+    const char *cnode;
+};
+
+static int testSocketTCPAccept(const void *opaque)
+{
+    virNetSocketPtr *lsock = NULL; /* Listen socket */
+    size_t nlsock = 0, i;
+    virNetSocketPtr ssock = NULL; /* Server socket */
+    virNetSocketPtr csock = NULL; /* Client socket */
+    const struct testTCPData *data = opaque;
+    int ret = -1;
+    char portstr[100];
+
+    snprintf(portstr, sizeof(portstr), "%d", data->port);
+
+    if (virNetSocketNewListenTCP(data->lnode, portstr, &lsock, &nlsock) < 0)
+        goto cleanup;
+
+    for (i = 0 ; i < nlsock ; i++) {
+        if (virNetSocketListen(lsock[i]) < 0)
+            goto cleanup;
+    }
+
+    if (virNetSocketNewConnectTCP(data->cnode, portstr, &csock) < 0)
+        goto cleanup;
+    /* Drop the client again before the server accepts it */
+    virNetSocketFree(csock);
+    /* Any connection accepted now must fail a write or a read */
+    for (i = 0 ; i < nlsock ; i++) {
+        if (virNetSocketAccept(lsock[i], &ssock) != -1 && ssock) {
+            char c = 'a';
+            if (virNetSocketWrite(ssock, &c, 1) != -1 &&
+                virNetSocketRead(ssock, &c, 1) != -1) {
+                VIR_DEBUG0("Unexpected client socket present");
+                goto cleanup;
+            }
+        }
+        virNetSocketFree(ssock);
+        ssock = NULL;
+    }
+
+    ret = 0;
+
+cleanup:
+    virNetSocketFree(ssock);
+    for (i = 0 ; i < nlsock ; i++)
+        virNetSocketFree(lsock[i]);
+    VIR_FREE(lsock);
+    return ret;
+}
+#endif
+
+/* UNIX socket, command and SSH tunnel tests: not built on Windows */
+#ifndef WIN32
+static int testSocketUNIXAccept(const void *data ATTRIBUTE_UNUSED)
+{
+    virNetSocketPtr lsock = NULL; /* Listen socket */
+    virNetSocketPtr ssock = NULL; /* Server socket */
+    virNetSocketPtr csock = NULL; /* Client socket */
+    int ret = -1;
+    /* Socket path is "<argv0>-test.sock", made absolute using cwd */
+    char *path = NULL;
+    if (argv0[0] == '/') {
+        if (virAsprintf(&path, "%s-test.sock", argv0) < 0) {
+            virReportOOMError();
+            goto cleanup;
+        }
+    } else {
+        if (virAsprintf(&path, "%s/%s-test.sock", cwd, argv0) < 0) {
+            virReportOOMError();
+            goto cleanup;
+        }
+    }
+
+    if (virNetSocketNewListenUNIX(path, 0700, getgid(), &lsock) < 0)
+        goto cleanup;
+
+    if (virNetSocketListen(lsock) < 0)
+        goto cleanup;
+
+    if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
+        goto cleanup;
+
+    virNetSocketFree(csock);
+    /* The client is gone, so writing to any accepted socket must fail */
+    if (virNetSocketAccept(lsock, &ssock) != -1 && ssock) {
+        char c = 'a';
+        if (virNetSocketWrite(ssock, &c, 1) != -1) {
+            VIR_DEBUG0("Unexpected client socket present");
+            goto cleanup;
+        }
+    }
+
+    ret = 0;
+
+cleanup:
+    VIR_FREE(path);
+    virNetSocketFree(lsock);
+    virNetSocketFree(ssock);
+    return ret;
+}
+
+/* Check the address strings reported for UNIX listen/client/server sockets */
+static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED)
+{
+    virNetSocketPtr lsock = NULL; /* Listen socket */
+    virNetSocketPtr ssock = NULL; /* Server socket */
+    virNetSocketPtr csock = NULL; /* Client socket */
+    int ret = -1;
+    /* Socket path is "<argv0>-test.sock", made absolute using cwd */
+    char *path = NULL;
+    if (argv0[0] == '/') {
+        if (virAsprintf(&path, "%s-test.sock", argv0) < 0) {
+            virReportOOMError();
+            goto cleanup;
+        }
+    } else {
+        if (virAsprintf(&path, "%s/%s-test.sock", cwd, argv0) < 0) {
+            virReportOOMError();
+            goto cleanup;
+        }
+    }
+
+    if (virNetSocketNewListenUNIX(path, 0700, getgid(), &lsock) < 0)
+        goto cleanup;
+
+    if (STRNEQ(virNetSocketLocalAddrString(lsock), "127.0.0.1;0")) {
+        VIR_DEBUG0("Unexpected local address");
+        goto cleanup;
+    }
+
+    if (virNetSocketRemoteAddrString(lsock) != NULL) {
+        VIR_DEBUG0("Unexpected remote address");
+        goto cleanup;
+    }
+
+    if (virNetSocketListen(lsock) < 0)
+        goto cleanup;
+
+    if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
+        goto cleanup;
+
+    if (STRNEQ(virNetSocketLocalAddrString(csock), "127.0.0.1;0")) {
+        VIR_DEBUG0("Unexpected local address");
+        goto cleanup;
+    }
+
+    if (STRNEQ(virNetSocketRemoteAddrString(csock), "127.0.0.1;0")) {
+        VIR_DEBUG0("Unexpected remote address");
+        goto cleanup;
+    }
+
+    /* The connected client must be pending on the listen socket */
+    if (virNetSocketAccept(lsock, &ssock) < 0 || !ssock) {
+        VIR_DEBUG0("Unexpected client socket missing");
+        goto cleanup;
+    }
+
+    /* The server end must report the same fixed address strings */
+    if (STRNEQ(virNetSocketLocalAddrString(ssock), "127.0.0.1;0")) {
+        VIR_DEBUG0("Unexpected local address");
+        goto cleanup;
+    }
+
+    if (STRNEQ(virNetSocketRemoteAddrString(ssock), "127.0.0.1;0")) {
+        VIR_DEBUG0("Unexpected remote address");
+        goto cleanup;
+    }
+
+
+    ret = 0;
+
+cleanup:
+    VIR_FREE(path);
+    virNetSocketFree(lsock);
+    virNetSocketFree(ssock);
+    virNetSocketFree(csock);
+    return ret;
+}
+
+static int testSocketCommandNormal(const void *data ATTRIBUTE_UNUSED)
+{
+    virNetSocketPtr csock = NULL; /* Client socket */
+    char buf[100];
+    ssize_t i, nread;
+    int ret = -1;
+    virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/zero", NULL);
+    virCommandAddEnvPassCommon(cmd);
+    /* "cat /dev/zero" produces NUL bytes; read some and check only those read */
+    if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
+        goto cleanup;
+
+    virNetSocketSetBlocking(csock, true);
+
+    if ((nread = virNetSocketRead(csock, buf, sizeof(buf))) <= 0)
+        goto cleanup;
+
+    for (i = 0 ; i < nread ; i++)
+        if (buf[i] != '\0')
+            goto cleanup;
+
+    ret = 0;
+
+cleanup:
+    virNetSocketFree(csock);
+    return ret;
+}
+
+static int testSocketCommandFail(const void *data ATTRIBUTE_UNUSED)
+{
+    virNetSocketPtr csock = NULL; /* Client socket */
+    char buf[100];
+    int ret = -1;
+    virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/does-not-exist", NULL);
+    virCommandAddEnvPassCommon(cmd);
+    /* cat fails at once; a read returning 0 (plain EOF) counts as a test failure */
+    if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
+        goto cleanup;
+
+    virNetSocketSetBlocking(csock, true);
+
+    if (virNetSocketRead(csock, buf, sizeof(buf)) == 0)
+        goto cleanup;
+
+    ret = 0;
+
+cleanup:
+    virNetSocketFree(csock);
+    return ret;
+}
+
+struct testSSHData {
+    const char *nodename;
+    const char *service;
+    const char *binary;
+    const char *username;
+    bool noTTY;
+    const char *netcat;
+    const char *path;
+    /* Expected outcome */
+    const char *expectOut;
+    bool failConnect;
+    bool dieEarly;
+};
+
+static int testSocketSSH(const void *opaque)
+{
+    const struct testSSHData *data = opaque;
+    virNetSocketPtr csock = NULL; /* Client socket */
+    int ret = -1;
+    char buf[1024];
+
+    if (virNetSocketNewConnectSSH(data->nodename,
+                                  data->service,
+                                  data->binary,
+                                  data->username,
+                                  data->noTTY,
+                                  data->netcat,
+                                  data->path,
+                                  &csock) < 0)
+        goto cleanup;
+
+    virNetSocketSetBlocking(csock, true);
+    /* failConnect: the read must fail. Else output must equal expectOut, and with dieEarly a 2nd read must fail */
+    if (data->failConnect) {
+        if (virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0)
+            goto cleanup;
+    } else {
+        ssize_t rv;
+        if ((rv = virNetSocketRead(csock, buf, sizeof(buf)-1)) < 0)
+            goto cleanup;
+        buf[rv] = '\0';
+
+        if (STRNEQ(buf, data->expectOut)) {
+            virtTestDifference(stderr, data->expectOut, buf);
+            goto cleanup;
+        }
+
+        if (data->dieEarly &&
+            virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0)
+            goto cleanup;
+    }
+
+    ret = 0;
+
+cleanup:
+    virNetSocketFree(csock);
+    return ret;
+}
+
+#endif
+
+
+static int
+mymain(int argc, char **argv)
+{
+    int ret = 0;
+#ifdef HAVE_IFADDRS_H
+    bool hasIPv4, hasIPv6;
+    int freePort;
+#endif
+
+    argv0 = argv[0];
+
+    if (argc > 1) {
+        fprintf(stderr, "Usage: %s\n", argv0);
+        return (EXIT_FAILURE);
+    }
+
+    signal(SIGPIPE, SIG_IGN);
+
+    if (!(getcwd(cwd, sizeof(cwd))))
+        return (EXIT_FAILURE);
+
+#ifdef HAVE_IFADDRS_H
+    if (checkProtocols(&hasIPv4, &hasIPv6, &freePort) < 0) {
+        fprintf(stderr, "Cannot identify IPv4/6 availability\n");
+        return (EXIT_FAILURE);
+    }
+
+    if (hasIPv4) {
+        struct testTCPData tcpData = { "127.0.0.1", freePort, "127.0.0.1" };
+        if (virtTestRun("Socket TCP/IPv4 Accept", 1, testSocketTCPAccept, &tcpData) < 0)
+            ret = -1;
+    }
+    if (hasIPv6) {
+        struct testTCPData tcpData = { "::1", freePort, "::1" };
+        if (virtTestRun("Socket TCP/IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0)
+            ret = -1;
+
} + if (hasIPv6 && hasIPv4) { + struct testTCPData tcpData = { NULL, freePort, "127.0.0.1" }; + if (virtTestRun("Socket TCP/IPv4+IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0) + ret = -1; + + tcpData.cnode = "::1"; + if (virtTestRun("Socket TCP/IPv4+IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0) + ret = -1; + } +#endif + +#ifndef WIN32 + if (virtTestRun("Socket UNIX Accept", 1, testSocketUNIXAccept, NULL) < 0) + ret = -1; + + if (virtTestRun("Socket UNIX Addrs", 1, testSocketUNIXAddrs, NULL) < 0) + ret = -1; + + if (virtTestRun("Socket External Command /dev/zero", 1, testSocketCommandNormal, NULL) < 0) + ret = -1; + if (virtTestRun("Socket External Command /dev/does-not-exist", 1, testSocketCommandFail, NULL) < 0) + ret = -1; + + struct testSSHData sshData1 = { + .nodename = "somehost", + .path = "/tmp/socket", + .expectOut = "somehost nc -U /tmp/socket\n", + }; + if (virtTestRun("SSH test 1", 1, testSocketSSH, &sshData1) < 0) + ret = -1; + + struct testSSHData sshData2 = { + .nodename = "somehost", + .service = "9000", + .username = "fred", + .netcat = "netcat", + .noTTY = true, + .path = "/tmp/socket", + .expectOut = "-p 9000 -l fred -T -o BatchMode=yes -e none somehost netcat -U /tmp/socket\n", + }; + if (virtTestRun("SSH test 2", 1, testSocketSSH, &sshData2) < 0) + ret = -1; + + struct testSSHData sshData3 = { + .nodename = "nosuchhost", + .path = "/tmp/socket", + .failConnect = true, + }; + if (virtTestRun("SSH test 3", 1, testSocketSSH, &sshData3) < 0) + ret = -1; + + struct testSSHData sshData4 = { + .nodename = "crashyhost", + .path = "/tmp/socket", + .expectOut = "crashyhost nc -U /tmp/socket\n", + .dieEarly = true, + }; + if (virtTestRun("SSH test 4", 1, testSocketSSH, &sshData4) < 0) + ret = -1; + +#endif + + return (ret==0 ? EXIT_SUCCESS : EXIT_FAILURE); +} + +VIRT_TEST_MAIN(mymain) -- 1.7.4

Add a test case which validates the RPC message encoding & decoding to/from XDR representation. Covers the core message header, the error class and streams. * testutils.c, testutils.h: Helper for printing binary differences * virnetmessagetest.c: Validate all XDR encoding/decoding --- tests/.gitignore | 1 + tests/Makefile.am | 8 +- tests/testutils.c | 62 ++++++ tests/testutils.h | 4 + tests/virnetmessagetest.c | 509 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 583 insertions(+), 1 deletions(-) create mode 100644 tests/virnetmessagetest.c diff --git a/tests/.gitignore b/tests/.gitignore index e272cf6..7f26dd7 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -31,6 +31,7 @@ statstest storagepoolxml2xmltest storagevolxml2xmltest virbuftest +virnetmessagetest virnetsockettest virshtest vmx2xmltest diff --git a/tests/Makefile.am b/tests/Makefile.am index abf8f30..841993b 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -78,7 +78,7 @@ EXTRA_DIST = \ check_PROGRAMS = virshtest conftest sockettest \ nodeinfotest qparamtest virbuftest \ commandtest commandhelper seclabeltest \ - virnetsockettest ssh + virnetsockettest ssh virnetmessagetest # This is a fake SSH we use from virnetsockettest ssh_SOURCES = ssh.c @@ -165,6 +165,7 @@ TESTS = virshtest \ commandtest \ seclabeltest \ virnetsockettest \ + virnetmessagetest \ $(test_scripts) if WITH_XEN @@ -372,6 +373,11 @@ virnetsockettest_SOURCES = \ virnetsockettest_CFLAGS = -Dabs_builddir="\"$(abs_builddir)\"" virnetsockettest_LDADD = $(LDADDS) +virnetmessagetest_SOURCES = \ + virnetmessagetest.c testutils.h testutils.c +virnetmessagetest_CFLAGS = -Dabs_builddir="\"$(abs_builddir)\"" +virnetmessagetest_LDADD = $(LDADDS) + seclabeltest_SOURCES = \ seclabeltest.c diff --git a/tests/testutils.c b/tests/testutils.c index 9b3cf59..6c304f8 100644 --- a/tests/testutils.c +++ b/tests/testutils.c @@ -374,6 +374,68 @@ int virtTestDifference(FILE *stream, return 0; } +/** + * @param stream: output stream
to write differences to + * @param expect: expected binary data + * @param actual: actual binary data + * @param length: number of bytes in both @expect and @actual + * When debugging, print both buffers as hex; below debug level 2 the + * output is trimmed to the 4-byte words spanning the differences + */ +int virtTestDifferenceBin(FILE *stream, + const char *expect, + const char *actual, + size_t length) +{ + size_t start = 0, end = length; + ssize_t i; + + if (!virTestGetDebug()) + return 0; + + if (virTestGetDebug() < 2) { + /* Skip to first character where they differ */ + for (i = 0 ; i < length ; i++) { + if (expect[i] != actual[i]) { + start = i; + break; + } + } + + /* Work backwards to last character where they differ */ + for (i = (length -1) ; i >= 0 ; i--) { + if (expect[i] != actual[i]) { + end = i; + break; + } + } + } + /* Round out to 4-byte boundaries, but never past @length */ + start -= (start % 4); + end += 4 - (end % 4); + end = end > length ? length : end; + /* Show the trimmed differences */ + fprintf(stream, "\nExpect [ Region %d-%d", (int)start, (int)end); + for (i = start; i < end ; i++) { + if ((i % 4) == 0) + fprintf(stream, "\n "); + fprintf(stream, "0x%02x, ", ((int)expect[i])&0xff); + } + fprintf(stream, "]\n"); + fprintf(stream, "Actual [ Region %d-%d", (int)start, (int)end); + for (i = start; i < end ; i++) { + if ((i % 4) == 0) + fprintf(stream, "\n "); + fprintf(stream, "0x%02x, ", ((int)actual[i])&0xff); + } + fprintf(stream, "]\n"); + + /* Pad to line up with test name ... in virTestRun */ + fprintf(stream, " ... 
"); + + return 0; +} + #if TEST_OOM static void virtTestErrorFuncQuiet(void *data ATTRIBUTE_UNUSED, diff --git a/tests/testutils.h b/tests/testutils.h index 88603a1..0a7321a 100644 --- a/tests/testutils.h +++ b/tests/testutils.h @@ -36,6 +36,10 @@ int virtTestClearLineRegex(const char *pattern, int virtTestDifference(FILE *stream, const char *expect, const char *actual); +int virtTestDifferenceBin(FILE *stream, + const char *expect, + const char *actual, + size_t length); unsigned int virTestGetDebug(void); unsigned int virTestGetVerbose(void); diff --git a/tests/virnetmessagetest.c b/tests/virnetmessagetest.c new file mode 100644 index 0000000..cb53845 --- /dev/null +++ b/tests/virnetmessagetest.c @@ -0,0 +1,509 @@ +/* + * Copyright (C) 2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P.
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <stdlib.h> +#include <signal.h> + +#include "testutils.h" +#include "util.h" +#include "virterror_internal.h" +#include "memory.h" +#include "logging.h" +#include "ignore-value.h" + +#include "rpc/virnetmessage.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +static char *argv0; +static char cwd[PATH_MAX]; + +static int testMessageHeaderEncode(const void *args ATTRIBUTE_UNUSED) +{ + virNetMessage msg; + const char expect[] = { + 0x00, 0x00, 0x00, 0x1c, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x00, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x00, /* Status */ + }; + memset(&msg, 0, sizeof(msg)); + + msg.header.prog = 0x11223344; + msg.header.vers = 0x01; + msg.header.proc = 0x666; + msg.header.type = VIR_NET_CALL; + msg.header.serial = 0x99; + msg.header.status = VIR_NET_OK; + + if (virNetMessageEncodeHeader(&msg) < 0) + return -1; + + if (ARRAY_CARDINALITY(expect) != msg.bufferOffset) { + VIR_DEBUG("Expect message offset %zu got %zu", + sizeof(expect), msg.bufferOffset); + return -1; + } + + if (msg.bufferLength != sizeof(msg.buffer)) { + VIR_DEBUG("Expect message length %zu got %zu", + sizeof(msg.buffer), msg.bufferLength); + return -1; + } + + if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) { + virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect)); + return -1; + } + + return 0; +} + +static int testMessageHeaderDecode(const void *args ATTRIBUTE_UNUSED) +{ + virNetMessage msg = { + .bufferOffset = 0, + .bufferLength = 0x4, + .buffer = { + 0x00, 0x00, 0x00, 0x1c, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x01, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x01, /* Status */ + }, + .header = { 0, 0, 0, 0, 0, 0 
}, + }; + + /* msg.header is deliberately left zeroed (see the .header + * initializer above) rather than pre-seeded with the values + * checked below: every header field asserted after + * virNetMessageDecodeHeader() must come from the encoded + * buffer, otherwise a decoder that skipped a field could + * still pass this test. */ + + if (virNetMessageDecodeLength(&msg) < 0) { + VIR_DEBUG0("Failed to decode message length"); + return -1; + } + + if (msg.bufferOffset != 0x4) { + VIR_DEBUG("Expecting offset %zu got %zu", + (size_t)4, msg.bufferOffset); + return -1; + } + + if (msg.bufferLength != 0x1c) { + VIR_DEBUG("Expecting length %zu got %zu", + (size_t)0x1c, msg.bufferLength); + return -1; + } + + if (virNetMessageDecodeHeader(&msg) < 0) { + VIR_DEBUG0("Failed to decode message header"); + return -1; + } + + if (msg.bufferOffset != msg.bufferLength) { + VIR_DEBUG("Expect message offset %zu got %zu", + msg.bufferLength, msg.bufferOffset); + return -1; + } + + if (msg.header.prog != 0x11223344) { + VIR_DEBUG("Expect prog %d got %d", + 0x11223344, msg.header.prog); + return -1; + } + if (msg.header.vers != 0x1) { + VIR_DEBUG("Expect vers %d got %d", + 0x1, msg.header.vers); + return -1; + } + if (msg.header.proc != 0x666) { + VIR_DEBUG("Expect proc %d got %d", + 0x666, msg.header.proc); + return -1; + } + if (msg.header.type != VIR_NET_REPLY) { + VIR_DEBUG("Expect type %d got %d", + VIR_NET_REPLY, msg.header.type); + return -1; + } + if (msg.header.serial != 0x99) { + VIR_DEBUG("Expect serial %d got %d", + 0x99, msg.header.serial); + return -1; + } + if (msg.header.status != VIR_NET_ERROR) { + VIR_DEBUG("Expect status %d got %d", + VIR_NET_ERROR, msg.header.status); + return -1; + } + + return 0; +} + +static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED) +{ + virNetMessage msg; + virNetMessageError err; + const char expect[] = { + 0x00, 0x00, 0x00, 0x74, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure 
*/ + 0x00, 0x00, 0x00, 0x02, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00,
0x00, 0x01, /* Status */ + + 0x00, 0x00, 0x00, 0x01, /* Error code */ + 0x00, 0x00, 0x00, 0x07, /* Error domain */ + 0x00, 0x00, 0x00, 0x01, /* Error message pointer */ + 0x00, 0x00, 0x00, 0x0b, /* Error message length */ + 'H', 'e', 'l', 'l', /* Error message string */ + 'o', ' ', 'W', 'o', + 'r', 'l', 'd', '\0', + 0x00, 0x00, 0x00, 0x02, /* Error level */ + 0x00, 0x00, 0x00, 0x00, /* Error domain pointer */ + 0x00, 0x00, 0x00, 0x01, /* Error str1 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str1 length */ + 'O', 'n', 'e', '\0', /* Error str1 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str2 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str2 length */ + 'T', 'w', 'o', '\0', /* Error str2 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str3 pointer */ + 0x00, 0x00, 0x00, 0x05, /* Error str3 length */ + 'T', 'h', 'r', 'e', /* Error str3 message */ + 'e', '\0', '\0', '\0', + 0x00, 0x00, 0x00, 0x01, /* Error int1 */ + 0x00, 0x00, 0x00, 0x02, /* Error int2 */ + 0x00, 0x00, 0x00, 0x00, /* Error network pointer */ + }; + memset(&msg, 0, sizeof(msg)); + memset(&err, 0, sizeof(err)); + + err.code = VIR_ERR_INTERNAL_ERROR; + err.domain = VIR_FROM_RPC; + if (VIR_ALLOC(err.message) < 0) + return -1; + *err.message = strdup("Hello World"); + err.level = VIR_ERR_ERROR; + if (VIR_ALLOC(err.str1) < 0) + return -1; + *err.str1 = strdup("One"); + if (VIR_ALLOC(err.str2) < 0) + return -1; + *err.str2 = strdup("Two"); + if (VIR_ALLOC(err.str3) < 0) + return -1; + *err.str3 = strdup("Three"); + err.int1 = 1; + err.int2 = 2; + + msg.header.prog = 0x11223344; + msg.header.vers = 0x01; + msg.header.proc = 0x666; + msg.header.type = VIR_NET_MESSAGE; + msg.header.serial = 0x99; + msg.header.status = VIR_NET_ERROR; + + if (virNetMessageEncodeHeader(&msg) < 0) + return -1; + + if (virNetMessageEncodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) + return -1; + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err); + if (ARRAY_CARDINALITY(expect) != msg.bufferLength) { + VIR_DEBUG("Expect message length %zu got 
%zu", +
sizeof(expect), msg.bufferLength); + return -1; + } + + if (msg.bufferOffset != 0) { + VIR_DEBUG("Expect message offset 0 got %zu", + msg.bufferOffset); + return -1; + } + + if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) { + virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect)); + return -1; + } + + return 0; +} + +static int testMessagePayloadDecode(const void *args ATTRIBUTE_UNUSED) +{ + virNetMessageError err; + virNetMessage msg = { + .bufferOffset = 0, + .bufferLength = 0x4, + .buffer = { + 0x00, 0x00, 0x00, 0x74, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x02, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x01, /* Status */ + + 0x00, 0x00, 0x00, 0x01, /* Error code */ + 0x00, 0x00, 0x00, 0x07, /* Error domain */ + 0x00, 0x00, 0x00, 0x01, /* Error message pointer */ + 0x00, 0x00, 0x00, 0x0b, /* Error message length */ + 'H', 'e', 'l', 'l', /* Error message string */ + 'o', ' ', 'W', 'o', + 'r', 'l', 'd', '\0', + 0x00, 0x00, 0x00, 0x02, /* Error level */ + 0x00, 0x00, 0x00, 0x00, /* Error domain pointer */ + 0x00, 0x00, 0x00, 0x01, /* Error str1 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str1 length */ + 'O', 'n', 'e', '\0', /* Error str1 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str2 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str2 length */ + 'T', 'w', 'o', '\0', /* Error str2 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str3 pointer */ + 0x00, 0x00, 0x00, 0x05, /* Error str3 length */ + 'T', 'h', 'r', 'e', /* Error str3 message */ + 'e', '\0', '\0', '\0', + 0x00, 0x00, 0x00, 0x01, /* Error int1 */ + 0x00, 0x00, 0x00, 0x02, /* Error int2 */ + 0x00, 0x00, 0x00, 0x00, /* Error network pointer */ + }, + .header = { 0, 0, 0, 0, 0, 0 }, + }; + memset(&err, 0, sizeof(err)); + + if (virNetMessageDecodeLength(&msg) < 0) { + VIR_DEBUG0("Failed to decode message length"); + return -1; + } + + if 
(msg.bufferOffset != 0x4) { + VIR_DEBUG("Expecting offset %zu got %zu", + (size_t)4, msg.bufferOffset); + return -1; + } + + if (msg.bufferLength != 0x74) { + VIR_DEBUG("Expecting length %zu got %zu", + (size_t)0x74, msg.bufferLength); + return -1; + } + + if (virNetMessageDecodeHeader(&msg) < 0) { + VIR_DEBUG0("Failed to decode message header"); + return -1; + } + + if (msg.bufferOffset != 28) { + VIR_DEBUG("Expect message offset %zu got %zu", + (size_t)28, msg.bufferOffset); + return -1; + } + + if (msg.bufferLength != 0x74) { + VIR_DEBUG("Expecting length %zu got %zu", + (size_t)0x74, msg.bufferLength); + return -1; + } + + if (virNetMessageDecodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) { + VIR_DEBUG0("Failed to decode message payload"); + return -1; + } + + if (err.code != VIR_ERR_INTERNAL_ERROR) { + VIR_DEBUG("Expect code %d got %d", + VIR_ERR_INTERNAL_ERROR, err.code); + return -1; + } + + if (err.domain != VIR_FROM_RPC) { + VIR_DEBUG("Expect domain %d got %d", + VIR_FROM_RPC, err.domain); + return -1; + } + + if (err.message == NULL || + STRNEQ(*err.message, "Hello World")) { + VIR_DEBUG("Expect message 'Hello World' got %s", + err.message ? *err.message : "(null)"); + return -1; + } + + if (err.dom != NULL) { + VIR_DEBUG0("Expect NULL dom"); + return -1; + } + + if (err.level != VIR_ERR_ERROR) { + VIR_DEBUG("Expect level %d got %d", + VIR_ERR_ERROR, err.level); + return -1; + } + + if (err.str1 == NULL || + STRNEQ(*err.str1, "One")) { + VIR_DEBUG("Expect str1 'One' got %s", + err.str1 ? *err.str1 : "(null)"); + return -1; + } + + if (err.str2 == NULL || + STRNEQ(*err.str2, "Two")) { + VIR_DEBUG("Expect str2 'Two' got %s", + err.str2 ? *err.str2 : "(null)"); + return -1; + } + + if (err.str3 == NULL || + STRNEQ(*err.str3, "Three")) { + VIR_DEBUG("Expect str3 'Three' got %s", + err.str3 ? 
*err.str3 : "(null)"); + return -1; + } + + if (err.int1 != 1) { + VIR_DEBUG("Expect int1 1 got %d", + err.int1); + return -1; + } + + if (err.int2 != 2) { + VIR_DEBUG("Expect int2 2 got %d", + err.int2); + return -1; + } + + if (err.net != NULL) { + VIR_DEBUG0("Expect NULL network"); + return -1; + } + + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err); + return 0; +} + +static int testMessagePayloadStreamEncode(const void *args ATTRIBUTE_UNUSED) +{ + char stream[] = "The quick brown fox jumps over the lazy dog"; + virNetMessage msg; + const char expect[] = { + 0x00, 0x00, 0x00, 0x47, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x03, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x02, /* Status */ + + 'T', 'h', 'e', ' ', + 'q', 'u', 'i', 'c', + 'k', ' ', 'b', 'r', + 'o', 'w', 'n', ' ', + 'f', 'o', 'x', ' ', + 'j', 'u', 'm', 'p', + 's', ' ', 'o', 'v', + 'e', 'r', ' ', 't', + 'h', 'e', ' ', 'l', + 'a', 'z', 'y', ' ', + 'd', 'o', 'g', + }; + memset(&msg, 0, sizeof(msg)); + + msg.header.prog = 0x11223344; + msg.header.vers = 0x01; + msg.header.proc = 0x666; + msg.header.type = VIR_NET_STREAM; + msg.header.serial = 0x99; + msg.header.status = VIR_NET_CONTINUE; + + if (virNetMessageEncodeHeader(&msg) < 0) + return -1; + + if (virNetMessageEncodePayloadRaw(&msg, stream, strlen(stream)) < 0) + return -1; + + if (ARRAY_CARDINALITY(expect) != msg.bufferLength) { + VIR_DEBUG("Expect message length %zu got %zu", + sizeof(expect), msg.bufferLength); + return -1; + } + + if (msg.bufferOffset != 0) { + VIR_DEBUG("Expect message offset 0 got %zu", + msg.bufferOffset); + return -1; + } + + if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) { + virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect)); + return -1; + } + + return 0; +} + + +static int +mymain(int argc, char **argv) +{ + int ret = 0; + + argv0 = argv[0]; + + if 
(argc > 1) { + fprintf(stderr, "Usage: %s\n", argv0); + return (EXIT_FAILURE); + } + + signal(SIGPIPE, SIG_IGN); + + if (!(getcwd(cwd, sizeof(cwd)))) + return (EXIT_FAILURE); + + if (virtTestRun("Message Header Encode", 1, testMessageHeaderEncode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Header Decode", 1, testMessageHeaderDecode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Payload Encode", 1, testMessagePayloadEncode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Payload Decode", 1, testMessagePayloadDecode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Payload Stream Encode", 1, testMessagePayloadStreamEncode, NULL) < 0) + ret = -1; + + return (ret==0 ? EXIT_SUCCESS : EXIT_FAILURE); +} + +VIRT_TEST_MAIN(mymain) -- 1.7.4
participants (2)
-
Daniel P. Berrange
-
Eric Blake