[libvirt] [PATCH 00/10 v5] New internal infrastructure for RPC

An update of https://www.redhat.com/archives/libvir-list/2011-May/msg00633.html

New in this series:

 - Nothing. Just rebased to latest GIT.

Again, I'm not posting the remote client / daemon conversion patches here, since I want to get this core infrastructure merged real soon and the client/daemon conversion would delay review.

Full series also available at http://gitorious.org/~berrange/libvirt/staging/commits/rpc

This patch defines the basics of a generic RPC protocol in XDR. This is wire ABI compatible with the original remote_protocol.x. It takes everything except for the RPC calls / events from that protocol - The basic header virNetMessageHeader (aka remote_message_header) - The error object virNetMessageError (aka remote_error) - Two dummy objects virNetMessageDomain & virNetMessageNetwork sadly needed to keep virNetMessageError ABI compatible with the old remote_error The RPC protocol supports method calls, async events and bidirectional data streams as before * src/Makefile.am: Add rules for generating RPC code from protocol & define a new libvirt-net-rpc.la helper library * src/rpc/virnetprotocol.x: New generic RPC protocol * src/rpc/virnetprotocol.c, src/rpc/virnetprotocol.h: Generated from virnetprotocol.x --- src/Makefile.am | 20 ++++- src/rpc/virnetprotocol.x | 217 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 233 insertions(+), 4 deletions(-) create mode 100644 src/rpc/virnetprotocol.x diff --git a/src/Makefile.am b/src/Makefile.am index 58e9f82..45905fa 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -552,13 +552,11 @@ libvirt_driver_remote_la_SOURCES = $(REMOTE_DRIVER_SOURCES) $(srcdir)/remote/remote_driver.c: $(REMOTE_DRIVER_GENERATED) -$(srcdir)/remote/%_protocol.c: $(srcdir)/remote/%_protocol.x \ - $(srcdir)/remote/%_protocol.h $(srcdir)/remote/rpcgen_fix.pl +%protocol.c: %protocol.x %protocol.h $(srcdir)/remote/rpcgen_fix.pl $(AM_V_GEN)perl -w $(srcdir)/remote/rpcgen_fix.pl $(RPCGEN) -c \ $< $@ -$(srcdir)/remote/%_protocol.h: $(srcdir)/remote/%_protocol.x \ - $(srcdir)/remote/rpcgen_fix.pl +%protocol.h: %protocol.x $(srcdir)/remote/rpcgen_fix.pl $(AM_V_GEN)perl -w $(srcdir)/remote/rpcgen_fix.pl $(RPCGEN) -h \ $< $@ @@ -1155,6 +1153,20 @@ libvirt_qemu_la_CFLAGS = $(AM_CFLAGS) libvirt_qemu_la_LIBADD = libvirt.la $(CYGWIN_EXTRA_LIBADD) EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) + +noinst_LTLIBRARIES += libvirt-net-rpc.la + 
+libvirt_net_rpc_la_SOURCES = \ + rpc/virnetprotocol.h rpc/virnetprotocol.c +libvirt_net_rpc_la_CFLAGS = \ + $(AM_CFLAGS) +libvirt_net_rpc_la_LDFLAGS = \ + $(AM_LDFLAGS) \ + $(CYGWIN_EXTRA_LDFLAGS) \ + $(MINGW_EXTRA_LDFLAGS) +libvirt_net_rpc_la_LIBADD = \ + $(CYGWIN_EXTRA_LIBADD) + libexec_PROGRAMS = if WITH_LIBVIRTD diff --git a/src/rpc/virnetprotocol.x b/src/rpc/virnetprotocol.x new file mode 100644 index 0000000..15066b8 --- /dev/null +++ b/src/rpc/virnetprotocol.x @@ -0,0 +1,217 @@ +/* -*- c -*- + * virnetprotocol.x: basic protocol for all RPC services. + * + * Copyright (C) 2006-2010 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Richard Jones <rjones@redhat.com> + */ + +%#include "internal.h" +%#include <arpa/inet.h> + +/* cygwin's xdr implementation defines xdr_u_int64_t instead of xdr_uint64_t + * and lacks IXDR_PUT_INT32 and IXDR_GET_INT32 + */ +%#ifdef HAVE_XDR_U_INT64_T +%# define xdr_uint64_t xdr_u_int64_t +%#endif +%#ifndef IXDR_PUT_INT32 +%# define IXDR_PUT_INT32 IXDR_PUT_LONG +%#endif +%#ifndef IXDR_GET_INT32 +%# define IXDR_GET_INT32 IXDR_GET_LONG +%#endif +%#ifndef IXDR_PUT_U_INT32 +%# define IXDR_PUT_U_INT32 IXDR_PUT_U_LONG +%#endif +%#ifndef IXDR_GET_U_INT32 +%# define IXDR_GET_U_INT32 IXDR_GET_U_LONG +%#endif + +/*----- Data types. -----*/ + +/* Maximum total message size (serialised). */ +const VIR_NET_MESSAGE_MAX = 262144; + +/* Size of struct virNetMessageHeader (serialised)*/ +const VIR_NET_MESSAGE_HEADER_MAX = 24; + +/* Size of message payload */ +const VIR_NET_MESSAGE_PAYLOAD_MAX = 262120; + +/* Size of message length field. Not counted in VIR_NET_MESSAGE_MAX */ +const VIR_NET_MESSAGE_LEN_MAX = 4; + +/* Length of long, but not unbounded, strings. + * This is an arbitrary limit designed to stop the decoder from trying + * to allocate unbounded amounts of memory when fed with a bad message. + */ +const VIR_NET_MESSAGE_STRING_MAX = 65536; + +/* + * RPC wire format + * + * Each message consists of: + * + * Name | Type | Description + * -----------+-----------------------+------------------ + * Length | int | Total number of bytes in message _including_ length. 
+ * Header | virNetMessageHeader | Control information about procedure call + * Payload | - | Variable payload data per procedure + * + * In header, the 'serial' field varies according to: + * + * - type == VIR_NET_CALL + * * serial is set by client, incrementing by 1 each time + * + * - type == VIR_NET_REPLY + * * serial matches that from the corresponding VIR_NET_CALL + * + * - type == VIR_NET_MESSAGE + * * serial is always zero + * + * - type == VIR_NET_STREAM + * * serial matches that from the corresponding VIR_NET_CALL + * + * and the 'status' field varies according to: + * + * - type == VIR_NET_CALL + * * VIR_NET_OK always + * + * - type == VIR_NET_REPLY + * * VIR_NET_OK if RPC finished successfully + * * VIR_NET_ERROR if something failed + * + * - type == VIR_NET_MESSAGE + * * VIR_NET_OK always + * + * - type == VIR_NET_STREAM + * * VIR_NET_CONTINUE if more data is following + * * VIR_NET_OK if stream is complete + * * VIR_NET_ERROR if stream had an error + * + * Payload varies according to type and status: + * + * - type == VIR_NET_CALL + * XXX_args for procedure + * + * - type == VIR_NET_REPLY + * * status == VIR_NET_OK + * XXX_ret for procedure + * * status == VIR_NET_ERROR + * remote_error Error information + * + * - type == VIR_NET_MESSAGE + * * status == VIR_NET_OK + * XXX_msg for event information + * + * - type == VIR_NET_STREAM + * * status == VIR_NET_CONTINUE + * byte[] raw stream data + * * status == VIR_NET_ERROR + * remote_error error information + * * status == VIR_NET_OK + * <empty> + */ +enum virNetMessageType { + /* client -> server. args from a method call */ + VIR_NET_CALL = 0, + /* server -> client. reply/error from a method call */ + VIR_NET_REPLY = 1, + /* either direction. async notification */ + VIR_NET_MESSAGE = 2, + /* either direction. stream data packet */ + VIR_NET_STREAM = 3 +}; + +enum virNetMessageStatus { + /* Status is always VIR_NET_OK for calls. + * For replies, indicates no error. 
+ */ + VIR_NET_OK = 0, + + /* For replies, indicates that an error happened, and a struct + * remote_error follows. + */ + VIR_NET_ERROR = 1, + + /* For streams, indicates that more data is still expected + */ + VIR_NET_CONTINUE = 2 +}; + +/* 4 byte length word per header */ +const VIR_NET_MESSAGE_HEADER_XDR_LEN = 4; + +struct virNetMessageHeader { + unsigned prog; /* Unique ID for the program */ + unsigned vers; /* Program version number */ + int proc; /* Unique ID for the procedure within the program */ + virNetMessageType type; /* Type of message */ + unsigned serial; /* Serial number of message. */ + virNetMessageStatus status; +}; + +/* Error message. See <virterror.h> for explanation of fields. */ + +/* Most of these don't really belong here. There are sadly needed + * for wire ABI backwards compatibility with the rather crazy + * error struct we previously defined :-( + */ + +typedef opaque virNetMessageUUID[VIR_UUID_BUFLEN]; +typedef string virNetMessageNonnullString<VIR_NET_MESSAGE_STRING_MAX>; + +/* A long string, which may be NULL. */ +typedef virNetMessageNonnullString *virNetMessageString; + +/* A domain which may not be NULL. */ +struct virNetMessageNonnullDomain { + virNetMessageNonnullString name; + virNetMessageUUID uuid; + int id; +}; + +/* A network which may not be NULL. */ +struct virNetMessageNonnullNetwork { + virNetMessageNonnullString name; + virNetMessageUUID uuid; +}; + + +typedef virNetMessageNonnullDomain *virNetMessageDomain; +typedef virNetMessageNonnullNetwork *virNetMessageNetwork; + +/* NB. Fields "code", "domain" and "level" are really enums. The + * numeric value should remain compatible between libvirt and + * libvirtd. This means, no changing or reordering the enums as + * defined in <virterror.h> (but we don't do that anyway, for separate + * ABI reasons). 
+ */ +struct virNetMessageError { + int code; + int domain; + virNetMessageString message; + int level; + virNetMessageDomain dom; /* unused */ + virNetMessageString str1; + virNetMessageString str2; + virNetMessageString str3; + int int1; + int int2; + virNetMessageNetwork net; /* unused */ +}; -- 1.7.4.4

This provides a new struct that contains a buffer for the RPC message header+payload, as well as a decoded copy of the message header. There is an API for applying an XDR encoding & decoding of the message headers and payloads. There are also APIs for maintaining a simple FIFO queue of message instances.

Expected usage scenarios are:

To send a message:

  msg = virNetMessageNew()
  ...fill in msg->header fields...
  virNetMessageEncodeHeader(msg)
  ...look at msg->header fields to determine payload filter
  virNetMessageEncodePayload(msg, xdrfilter, data)
  ...send msg->bufferLength worth of data from buffer

To receive a message:

  msg = virNetMessageNew()
  ...read VIR_NET_MESSAGE_LEN_MAX of data into buffer
  virNetMessageDecodeLength(msg)
  ...read msg->bufferLength-msg->bufferOffset of data into buffer
  virNetMessageDecodeHeader(msg)
  ...look at msg->header fields to determine payload filter
  virNetMessageDecodePayload(msg, xdrfilter, data)
  ...run payload processor

* src/Makefile.am: Add to libvirt-net-rpc.la
* src/rpc/virnetmessage.c, src/rpc/virnetmessage.h: Internal message handling API.
* testutils.c, testutils.h: Helper for printing binary differences * virnetmessagetest.c: Validate all XDR encoding/decoding --- cfg.mk | 1 + po/POTFILES.in | 1 + src/Makefile.am | 1 + src/rpc/virnetmessage.c | 365 +++++++++++++++++++++++++++++++++ src/rpc/virnetmessage.h | 82 ++++++++ tests/.gitignore | 1 + tests/Makefile.am | 9 +- tests/testutils.c | 62 ++++++ tests/testutils.h | 4 + tests/virnetmessagetest.c | 496 +++++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 1021 insertions(+), 1 deletions(-) create mode 100644 src/rpc/virnetmessage.c create mode 100644 src/rpc/virnetmessage.h create mode 100644 tests/virnetmessagetest.c diff --git a/cfg.mk b/cfg.mk index 3a10186..cf30929 100644 --- a/cfg.mk +++ b/cfg.mk @@ -125,6 +125,7 @@ useless_free_options = \ --name=virInterfaceProtocolDefFree \ --name=virJSONValueFree \ --name=virLastErrFreeData \ + --name=virNetMessageFree \ --name=virNWFilterDefFree \ --name=virNWFilterEntryFree \ --name=virNWFilterHashTableFree \ diff --git a/po/POTFILES.in b/po/POTFILES.in index dd44da2..a6048ec 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -67,6 +67,7 @@ src/qemu/qemu_monitor_text.c src/qemu/qemu_process.c src/remote/remote_client_bodies.h src/remote/remote_driver.c +src/rpc/virnetmessage.c src/secret/secret_driver.c src/security/security_apparmor.c src/security/security_dac.c diff --git a/src/Makefile.am b/src/Makefile.am index 45905fa..b24f319 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1157,6 +1157,7 @@ EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt_net_rpc_la_SOURCES = \ + rpc/virnetmessage.h rpc/virnetmessage.c \ rpc/virnetprotocol.h rpc/virnetprotocol.c libvirt_net_rpc_la_CFLAGS = \ $(AM_CFLAGS) diff --git a/src/rpc/virnetmessage.c b/src/rpc/virnetmessage.c new file mode 100644 index 0000000..1cd3ab3 --- /dev/null +++ b/src/rpc/virnetmessage.c @@ -0,0 +1,365 @@ +/* + * virnetmessage.c: basic RPC message encoding/decoding + * + * Copyright (C) 
2010-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include <config.h> + +#include <stdlib.h> + +#include "virnetmessage.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +virNetMessagePtr virNetMessageNew(void) +{ + virNetMessagePtr msg; + + if (VIR_ALLOC(msg) < 0) { + virReportOOMError(); + return NULL; + } + + VIR_DEBUG("msg=%p", msg); + + return msg; +} + +void virNetMessageFree(virNetMessagePtr msg) +{ + if (!msg) + return; + + VIR_DEBUG("msg=%p", msg); + + VIR_FREE(msg); +} + +void virNetMessageQueuePush(virNetMessagePtr *queue, virNetMessagePtr msg) +{ + virNetMessagePtr tmp = *queue; + + if (tmp) { + while (tmp->next) + tmp = tmp->next; + tmp->next = msg; + } else { + *queue = msg; + } +} + + +virNetMessagePtr virNetMessageQueueServe(virNetMessagePtr *queue) +{ + virNetMessagePtr tmp = *queue; + + if (tmp) { + *queue = tmp->next; + tmp->next = NULL; + } + + return tmp; +} + + +int virNetMessageDecodeLength(virNetMessagePtr msg) +{ + XDR xdr; + unsigned int len; + int ret = -1; + + xdrmem_create(&xdr, msg->buffer, + msg->bufferLength, XDR_DECODE); 
+ if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message length")); + goto cleanup; + } + msg->bufferOffset = xdr_getpos(&xdr); + + if (len < VIR_NET_MESSAGE_LEN_MAX) { + virNetError(VIR_ERR_RPC, "%s", + _("packet received from server too small")); + goto cleanup; + } + + /* Length includes length word - adjust to real length to read. */ + len -= VIR_NET_MESSAGE_LEN_MAX; + + if (len > VIR_NET_MESSAGE_MAX) { + virNetError(VIR_ERR_RPC, "%s", + _("packet received from server too large")); + goto cleanup; + } + + /* Extend our declared buffer length and carry + on reading the header + payload */ + msg->bufferLength += len; + + VIR_DEBUG("Got length, now need %zu total (%u more)", + msg->bufferLength, len); + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +/* + * @msg: the complete incoming message, whose header to decode + * + * Decodes the header part of the message, but does not + * validate the decoded fields in the header. It expects + * bufferLength to refer to length of the data packet. Upon + * return bufferOffset will refer to the amount of the packet + * consumed by decoding of the header. + * + * returns 0 if successfully decoded, -1 upon fatal error + */ +int virNetMessageDecodeHeader(virNetMessagePtr msg) +{ + XDR xdr; + int ret = -1; + + msg->bufferOffset = VIR_NET_MESSAGE_LEN_MAX; + + /* Parse the header. */ + xdrmem_create(&xdr, + msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, + XDR_DECODE); + + if (!xdr_virNetMessageHeader(&xdr, &msg->header)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message header")); + goto cleanup; + } + + msg->bufferOffset += xdr_getpos(&xdr); + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +/* + * @msg: the outgoing message, whose header to encode + * + * Encodes the length word and header of the message, setting the + * message offset ready to encode the payload. Leaves space + * for the length field later. 
Upon return bufferLength will + * refer to the total available space for message, while + * bufferOffset will refer to current space used by header + * + * returns 0 if successfully encoded, -1 upon fatal error + */ +int virNetMessageEncodeHeader(virNetMessagePtr msg) +{ + XDR xdr; + int ret = -1; + unsigned int len = 0; + + msg->bufferLength = sizeof(msg->buffer); + msg->bufferOffset = 0; + + /* Format the header. */ + xdrmem_create(&xdr, + msg->buffer, + msg->bufferLength, + XDR_ENCODE); + + /* The real value is filled in shortly */ + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto cleanup; + } + + if (!xdr_virNetMessageHeader(&xdr, &msg->header)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message header")); + goto cleanup; + } + + len = xdr_getpos(&xdr); + xdr_setpos(&xdr, 0); + + /* Fill in current length - may be re-written later + * if a payload is added + */ + if (!xdr_u_int(&xdr, &len)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to re-encode message length")); + goto cleanup; + } + + msg->bufferOffset += len; + + ret = 0; + +cleanup: + xdr_destroy(&xdr); + return ret; +} + + +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) +{ + XDR xdr; + unsigned int msglen; + + /* Serialise payload of the message. This assumes that + * virNetMessageEncodeHeader has already been run, so + * just appends to that data */ + xdrmem_create(&xdr, msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, XDR_ENCODE); + + if (!(*filter)(&xdr, data)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message payload")); + goto error; + } + + /* Get the length stored in buffer. */ + msg->bufferOffset += xdr_getpos(&xdr); + xdr_destroy(&xdr); + + /* Re-encode the length word. 
*/ + VIR_DEBUG("Encode length as %zu", msg->bufferOffset); + xdrmem_create(&xdr, msg->buffer, VIR_NET_MESSAGE_HEADER_XDR_LEN, XDR_ENCODE); + msglen = msg->bufferOffset; + if (!xdr_u_int(&xdr, &msglen)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto error; + } + xdr_destroy(&xdr); + + msg->bufferLength = msg->bufferOffset; + msg->bufferOffset = 0; + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + + +int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) +{ + XDR xdr; + + /* Deserialise payload of the message. This assumes that + * virNetMessageDecodeHeader has already been run, so + * just start from after that data */ + xdrmem_create(&xdr, msg->buffer + msg->bufferOffset, + msg->bufferLength - msg->bufferOffset, XDR_DECODE); + + if (!(*filter)(&xdr, data)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to decode message payload")); + goto error; + } + + /* Get the length stored in buffer. */ + msg->bufferLength += xdr_getpos(&xdr); + xdr_destroy(&xdr); + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + + +int virNetMessageEncodePayloadRaw(virNetMessagePtr msg, + const char *data, + size_t len) +{ + XDR xdr; + unsigned int msglen; + + if ((msg->bufferLength - msg->bufferOffset) < len) { + virNetError(VIR_ERR_RPC, + _("Stream data too long to send (%zu bytes needed, %zu bytes available)"), + len, (msg->bufferLength - msg->bufferOffset)); + return -1; + } + + memcpy(msg->buffer + msg->bufferOffset, data, len); + msg->bufferOffset += len; + + /* Re-encode the length word. 
*/ + VIR_DEBUG("Encode length as %zu", msg->bufferOffset); + xdrmem_create(&xdr, msg->buffer, VIR_NET_MESSAGE_HEADER_XDR_LEN, XDR_ENCODE); + msglen = msg->bufferOffset; + if (!xdr_u_int(&xdr, &msglen)) { + virNetError(VIR_ERR_RPC, "%s", _("Unable to encode message length")); + goto error; + } + xdr_destroy(&xdr); + + msg->bufferLength = msg->bufferOffset; + msg->bufferOffset = 0; + return 0; + +error: + xdr_destroy(&xdr); + return -1; +} + + +void virNetMessageSaveError(virNetMessageErrorPtr rerr) +{ + /* This func may be called several times & the first + * error is the one we want because we don't want + * cleanup code overwriting the first one. + */ + if (rerr->code != VIR_ERR_OK) + return; + + virErrorPtr verr = virGetLastError(); + if (verr) { + rerr->code = verr->code; + rerr->domain = verr->domain; + rerr->message = verr->message ? malloc(sizeof(char*)) : NULL; + if (rerr->message) *rerr->message = strdup(verr->message); + rerr->level = verr->level; + rerr->str1 = verr->str1 ? malloc(sizeof(char*)) : NULL; + if (rerr->str1) *rerr->str1 = strdup(verr->str1); + rerr->str2 = verr->str2 ? malloc(sizeof(char*)) : NULL; + if (rerr->str2) *rerr->str2 = strdup(verr->str2); + rerr->str3 = verr->str3 ? malloc(sizeof(char*)) : NULL; + if (rerr->str3) *rerr->str3 = strdup(verr->str3); + rerr->int1 = verr->int1; + rerr->int2 = verr->int2; + } else { + rerr->code = VIR_ERR_INTERNAL_ERROR; + rerr->domain = VIR_FROM_RPC; + rerr->message = malloc(sizeof(char*)); + if (rerr->message) *rerr->message = strdup(_("Library function returned error but did not set virError")); + rerr->level = VIR_ERR_ERROR; + } +} diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h new file mode 100644 index 0000000..fbeb257 --- /dev/null +++ b/src/rpc/virnetmessage.h @@ -0,0 +1,82 @@ +/* + * virnetmessage.h: basic RPC message encoding/decoding + * + * Copyright (C) 2010-2011 Red Hat, Inc. 
+ * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef __VIR_NET_MESSAGE_H__ +# define __VIR_NET_MESSAGE_H__ + +# include <stdbool.h> + +# include "virnetprotocol.h" + +typedef struct virNetMessageHeader *virNetMessageHeaderPtr; +typedef struct virNetMessageError *virNetMessageErrorPtr; + +typedef struct _virNetMessage virNetMessage; +typedef virNetMessage *virNetMessagePtr; + +/* Never allocate this (huge) buffer on the stack. 
Always + * use virNetMessageNew() to allocate on the heap + */ +struct _virNetMessage { + char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX]; + size_t bufferLength; + size_t bufferOffset; + + virNetMessageHeader header; + + virNetMessagePtr next; +}; + + +virNetMessagePtr virNetMessageNew(void); + +void virNetMessageFree(virNetMessagePtr msg); + +virNetMessagePtr virNetMessageQueueServe(virNetMessagePtr *queue) + ATTRIBUTE_NONNULL(1); +void virNetMessageQueuePush(virNetMessagePtr *queue, + virNetMessagePtr msg) + ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2); + +int virNetMessageEncodeHeader(virNetMessagePtr msg) + ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK; +int virNetMessageDecodeLength(virNetMessagePtr msg) + ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK; +int virNetMessageDecodeHeader(virNetMessagePtr msg) + ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK; + +int virNetMessageEncodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) + ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK; +int virNetMessageDecodePayload(virNetMessagePtr msg, + xdrproc_t filter, + void *data) + ATTRIBUTE_NONNULL(1) ATTRIBUTE_NONNULL(2) ATTRIBUTE_NONNULL(2) ATTRIBUTE_RETURN_CHECK; + +int virNetMessageEncodePayloadRaw(virNetMessagePtr msg, + const char *buf, + size_t len) + ATTRIBUTE_NONNULL(1) ATTRIBUTE_RETURN_CHECK; + +void virNetMessageSaveError(virNetMessageErrorPtr rerr) + ATTRIBUTE_NONNULL(1); + +#endif /* __VIR_NET_MESSAGE_H__ */ diff --git a/tests/.gitignore b/tests/.gitignore index e3906f0..36115ea 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -30,6 +30,7 @@ statstest storagepoolxml2xmltest storagevolxml2xmltest virbuftest +virnetmessagetest virshtest vmx2xmltest xencapstest diff --git a/tests/Makefile.am b/tests/Makefile.am index bc171d2..f80b98f 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -78,7 +78,7 @@ EXTRA_DIST = \ check_PROGRAMS = virshtest conftest sockettest \ nodeinfotest qparamtest virbuftest \ 
commandtest commandhelper seclabeltest \ - hashtest + hashtest virnetmessagetest if WITH_XEN check_PROGRAMS += xml2sexprtest sexpr2xmltest \ @@ -180,6 +180,7 @@ TESTS = virshtest \ commandtest \ seclabeltest \ hashtest \ + virnetmessagetest \ $(test_scripts) if WITH_XEN @@ -382,6 +383,12 @@ commandhelper_SOURCES = \ commandhelper_CFLAGS = -Dabs_builddir="\"`pwd`\"" commandhelper_LDADD = $(LDADDS) +virnetmessagetest_SOURCES = \ + virnetmessagetest.c testutils.h testutils.c +virnetmessagetest_CFLAGS = -Dabs_builddir="\"$(abs_builddir)\"" +virnetmessagetest_LDADD = $(LDADDS) + + seclabeltest_SOURCES = \ seclabeltest.c seclabeltest_LDADD = ../src/libvirt_driver_security.la $(LDADDS) diff --git a/tests/testutils.c b/tests/testutils.c index bc89690..d87347d 100644 --- a/tests/testutils.c +++ b/tests/testutils.c @@ -370,6 +370,68 @@ int virtTestDifference(FILE *stream, return 0; } +/** + * @param stream: output stream write to differences to + * @param expect: expected output text + * @param actual: actual output text + * + * Display expected and actual output text, trimmed to + * first and last characters at which differences occur + */ +int virtTestDifferenceBin(FILE *stream, + const char *expect, + const char *actual, + size_t length) +{ + size_t start = 0, end = length; + ssize_t i; + + if (!virTestGetDebug()) + return 0; + + if (virTestGetDebug() < 2) { + /* Skip to first character where they differ */ + for (i = 0 ; i < length ; i++) { + if (expect[i] != actual[i]) { + start = i; + break; + } + } + + /* Work backwards to last character where they differ */ + for (i = (length -1) ; i >= 0 ; i--) { + if (expect[i] != actual[i]) { + end = i; + break; + } + } + } + /* Round to nearest boundary of 4 */ + start -= (start % 4); + end += 4 - (end % 4); + + /* Show the trimmed differences */ + fprintf(stream, "\nExpect [ Region %d-%d", (int)start, (int)end); + for (i = start; i < end ; i++) { + if ((i % 4) == 0) + fprintf(stream, "\n "); + fprintf(stream, "0x%02x, ", 
((int)expect[i])&0xff); + } + fprintf(stream, "]\n"); + fprintf(stream, "Actual [ Region %d-%d", (int)start, (int)end); + for (i = start; i < end ; i++) { + if ((i % 4) == 0) + fprintf(stream, "\n "); + fprintf(stream, "0x%02x, ", ((int)actual[i])&0xff); + } + fprintf(stream, "]\n"); + + /* Pad to line up with test name ... in virTestRun */ + fprintf(stream, " ... "); + + return 0; +} + #if TEST_OOM static void virtTestErrorFuncQuiet(void *data ATTRIBUTE_UNUSED, diff --git a/tests/testutils.h b/tests/testutils.h index e8f4153..03d8dc6 100644 --- a/tests/testutils.h +++ b/tests/testutils.h @@ -36,6 +36,10 @@ int virtTestClearLineRegex(const char *pattern, int virtTestDifference(FILE *stream, const char *expect, const char *actual); +int virtTestDifferenceBin(FILE *stream, + const char *expect, + const char *actual, + size_t length); unsigned int virTestGetDebug(void); unsigned int virTestGetVerbose(void); diff --git a/tests/virnetmessagetest.c b/tests/virnetmessagetest.c new file mode 100644 index 0000000..384ab43 --- /dev/null +++ b/tests/virnetmessagetest.c @@ -0,0 +1,496 @@ +/* + * Copyright (C) 2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <stdlib.h> +#include <signal.h> + +#include "testutils.h" +#include "util.h" +#include "virterror_internal.h" +#include "memory.h" +#include "logging.h" +#include "ignore-value.h" + +#include "rpc/virnetmessage.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +static int testMessageHeaderEncode(const void *args ATTRIBUTE_UNUSED) +{ + static virNetMessage msg; + static const char expect[] = { + 0x00, 0x00, 0x00, 0x1c, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x00, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x00, /* Status */ + }; + memset(&msg, 0, sizeof(msg)); + + msg.header.prog = 0x11223344; + msg.header.vers = 0x01; + msg.header.proc = 0x666; + msg.header.type = VIR_NET_CALL; + msg.header.serial = 0x99; + msg.header.status = VIR_NET_OK; + + if (virNetMessageEncodeHeader(&msg) < 0) + return -1; + + if (ARRAY_CARDINALITY(expect) != msg.bufferOffset) { + VIR_DEBUG("Expect message offset %zu got %zu", + sizeof(expect), msg.bufferOffset); + return -1; + } + + if (msg.bufferLength != sizeof(msg.buffer)) { + VIR_DEBUG("Expect message offset %zu got %zu", + sizeof(msg.buffer), msg.bufferLength); + return -1; + } + + if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) { + virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect)); + return -1; + } + + return 0; +} + +static int testMessageHeaderDecode(const void *args ATTRIBUTE_UNUSED) +{ + static virNetMessage msg = { + .bufferOffset = 0, + .bufferLength = 0x4, + .buffer = { + 0x00, 0x00, 0x00, 0x1c, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x01, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x01, /* Status */ + }, + .header = { 0, 0, 0, 0, 0, 0 }, + }; + + msg.header.prog = 
0x11223344; + msg.header.vers = 0x01; + msg.header.proc = 0x666; + msg.header.type = VIR_NET_CALL; + msg.header.serial = 0x99; + msg.header.status = VIR_NET_OK; + + if (virNetMessageDecodeLength(&msg) < 0) { + VIR_DEBUG("Failed to decode message header"); + return -1; + } + + if (msg.bufferOffset != 0x4) { + VIR_DEBUG("Expecting offset %zu got %zu", + (size_t)4, msg.bufferOffset); + return -1; + } + + if (msg.bufferLength != 0x1c) { + VIR_DEBUG("Expecting length %zu got %zu", + (size_t)0x1c, msg.bufferLength); + return -1; + } + + if (virNetMessageDecodeHeader(&msg) < 0) { + VIR_DEBUG("Failed to decode message header"); + return -1; + } + + if (msg.bufferOffset != msg.bufferLength) { + VIR_DEBUG("Expect message offset %zu got %zu", + msg.bufferOffset, msg.bufferLength); + return -1; + } + + if (msg.header.prog != 0x11223344) { + VIR_DEBUG("Expect prog %d got %d", + 0x11223344, msg.header.prog); + return -1; + } + if (msg.header.vers != 0x1) { + VIR_DEBUG("Expect vers %d got %d", + 0x11223344, msg.header.vers); + return -1; + } + if (msg.header.proc != 0x666) { + VIR_DEBUG("Expect proc %d got %d", + 0x666, msg.header.proc); + return -1; + } + if (msg.header.type != VIR_NET_REPLY) { + VIR_DEBUG("Expect type %d got %d", + VIR_NET_REPLY, msg.header.type); + return -1; + } + if (msg.header.serial != 0x99) { + VIR_DEBUG("Expect serial %d got %d", + 0x99, msg.header.serial); + return -1; + } + if (msg.header.status != VIR_NET_ERROR) { + VIR_DEBUG("Expect status %d got %d", + VIR_NET_ERROR, msg.header.status); + return -1; + } + + return 0; +} + +static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED) +{ + virNetMessageError err; + static virNetMessage msg; + static const char expect[] = { + 0x00, 0x00, 0x00, 0x74, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x02, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x01, /* Status 
*/ + + 0x00, 0x00, 0x00, 0x01, /* Error code */ + 0x00, 0x00, 0x00, 0x07, /* Error domain */ + 0x00, 0x00, 0x00, 0x01, /* Error message pointer */ + 0x00, 0x00, 0x00, 0x0b, /* Error message length */ + 'H', 'e', 'l', 'l', /* Error message string */ + 'o', ' ', 'W', 'o', + 'r', 'l', 'd', '\0', + 0x00, 0x00, 0x00, 0x02, /* Error level */ + 0x00, 0x00, 0x00, 0x00, /* Error domain pointer */ + 0x00, 0x00, 0x00, 0x01, /* Error str1 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str1 length */ + 'O', 'n', 'e', '\0', /* Error str1 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str2 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str2 length */ + 'T', 'w', 'o', '\0', /* Error str2 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str3 pointer */ + 0x00, 0x00, 0x00, 0x05, /* Error str3 length */ + 'T', 'h', 'r', 'e', /* Error str3 message */ + 'e', '\0', '\0', '\0', + 0x00, 0x00, 0x00, 0x01, /* Error int1 */ + 0x00, 0x00, 0x00, 0x02, /* Error int2 */ + 0x00, 0x00, 0x00, 0x00, /* Error network pointer */ + }; + memset(&msg, 0, sizeof(msg)); + memset(&err, 0, sizeof(err)); + + err.code = VIR_ERR_INTERNAL_ERROR; + err.domain = VIR_FROM_RPC; + if (VIR_ALLOC(err.message) < 0) + return -1; + *err.message = strdup("Hello World"); + err.level = VIR_ERR_ERROR; + if (VIR_ALLOC(err.str1) < 0) + return -1; + *err.str1 = strdup("One"); + if (VIR_ALLOC(err.str2) < 0) + return -1; + *err.str2 = strdup("Two"); + if (VIR_ALLOC(err.str3) < 0) + return -1; + *err.str3 = strdup("Three"); + err.int1 = 1; + err.int2 = 2; + + msg.header.prog = 0x11223344; + msg.header.vers = 0x01; + msg.header.proc = 0x666; + msg.header.type = VIR_NET_MESSAGE; + msg.header.serial = 0x99; + msg.header.status = VIR_NET_ERROR; + + if (virNetMessageEncodeHeader(&msg) < 0) + return -1; + + if (virNetMessageEncodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) + return -1; + + if (ARRAY_CARDINALITY(expect) != msg.bufferLength) { + VIR_DEBUG("Expect message length %zu got %zu", + sizeof(expect), msg.bufferLength); + 
return -1; + } + + if (msg.bufferOffset != 0) { + VIR_DEBUG("Expect message offset 0 got %zu", + msg.bufferOffset); + return -1; + } + + if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) { + virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect)); + return -1; + } + + return 0; +} + +static int testMessagePayloadDecode(const void *args ATTRIBUTE_UNUSED) +{ + virNetMessageError err; + static virNetMessage msg = { + .bufferOffset = 0, + .bufferLength = 0x4, + .buffer = { + 0x00, 0x00, 0x00, 0x74, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x02, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x01, /* Status */ + + 0x00, 0x00, 0x00, 0x01, /* Error code */ + 0x00, 0x00, 0x00, 0x07, /* Error domain */ + 0x00, 0x00, 0x00, 0x01, /* Error message pointer */ + 0x00, 0x00, 0x00, 0x0b, /* Error message length */ + 'H', 'e', 'l', 'l', /* Error message string */ + 'o', ' ', 'W', 'o', + 'r', 'l', 'd', '\0', + 0x00, 0x00, 0x00, 0x02, /* Error level */ + 0x00, 0x00, 0x00, 0x00, /* Error domain pointer */ + 0x00, 0x00, 0x00, 0x01, /* Error str1 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str1 length */ + 'O', 'n', 'e', '\0', /* Error str1 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str2 pointer */ + 0x00, 0x00, 0x00, 0x03, /* Error str2 length */ + 'T', 'w', 'o', '\0', /* Error str2 message */ + 0x00, 0x00, 0x00, 0x01, /* Error str3 pointer */ + 0x00, 0x00, 0x00, 0x05, /* Error str3 length */ + 'T', 'h', 'r', 'e', /* Error str3 message */ + 'e', '\0', '\0', '\0', + 0x00, 0x00, 0x00, 0x01, /* Error int1 */ + 0x00, 0x00, 0x00, 0x02, /* Error int2 */ + 0x00, 0x00, 0x00, 0x00, /* Error network pointer */ + }, + .header = { 0, 0, 0, 0, 0, 0 }, + }; + memset(&err, 0, sizeof(err)); + + if (virNetMessageDecodeLength(&msg) < 0) { + VIR_DEBUG("Failed to decode message header"); + return -1; + } + + if (msg.bufferOffset != 0x4) { + 
VIR_DEBUG("Expecting offset %zu got %zu", + (size_t)4, msg.bufferOffset); + return -1; + } + + if (msg.bufferLength != 0x74) { + VIR_DEBUG("Expecting length %zu got %zu", + (size_t)0x74, msg.bufferLength); + return -1; + } + + if (virNetMessageDecodeHeader(&msg) < 0) { + VIR_DEBUG("Failed to decode message header"); + return -1; + } + + if (msg.bufferOffset != 28) { + VIR_DEBUG("Expect message offset %zu got %zu", + msg.bufferOffset, (size_t)28); + return -1; + } + + if (msg.bufferLength != 0x74) { + VIR_DEBUG("Expecting length %zu got %zu", + (size_t)0x1c, msg.bufferLength); + return -1; + } + + if (virNetMessageDecodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) { + VIR_DEBUG("Failed to decode message payload"); + return -1; + } + + if (err.code != VIR_ERR_INTERNAL_ERROR) { + VIR_DEBUG("Expect code %d got %d", + VIR_ERR_INTERNAL_ERROR, err.code); + return -1; + } + + if (err.domain != VIR_FROM_RPC) { + VIR_DEBUG("Expect domain %d got %d", + VIR_ERR_RPC, err.domain); + return -1; + } + + if (err.message == NULL || + STRNEQ(*err.message, "Hello World")) { + VIR_DEBUG("Expect str1 'Hello World' got %s", + err.message ? *err.message : "(null)"); + return -1; + } + + if (err.dom != NULL) { + VIR_DEBUG("Expect NULL dom"); + return -1; + } + + if (err.level != VIR_ERR_ERROR) { + VIR_DEBUG("Expect leve %d got %d", + VIR_ERR_ERROR, err.level); + return -1; + } + + if (err.str1 == NULL || + STRNEQ(*err.str1, "One")) { + VIR_DEBUG("Expect str1 'One' got %s", + err.str1 ? *err.str1 : "(null)"); + return -1; + } + + if (err.str2 == NULL || + STRNEQ(*err.str2, "Two")) { + VIR_DEBUG("Expect str3 'Two' got %s", + err.str2 ? *err.str2 : "(null)"); + return -1; + } + + if (err.str3 == NULL || + STRNEQ(*err.str3, "Three")) { + VIR_DEBUG("Expect str3 'Three' got %s", + err.str3 ? 
*err.str3 : "(null)"); + return -1; + } + + if (err.int1 != 1) { + VIR_DEBUG("Expect int1 1 got %d", + err.int1); + return -1; + } + + if (err.int2 != 2) { + VIR_DEBUG("Expect int2 2 got %d", + err.int2); + return -1; + } + + if (err.net != NULL) { + VIR_DEBUG("Expect NULL network"); + return -1; + } + + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err); + return 0; +} + +static int testMessagePayloadStreamEncode(const void *args ATTRIBUTE_UNUSED) +{ + char stream[] = "The quick brown fox jumps over the lazy dog"; + static virNetMessage msg; + static const char expect[] = { + 0x00, 0x00, 0x00, 0x47, /* Length */ + 0x11, 0x22, 0x33, 0x44, /* Program */ + 0x00, 0x00, 0x00, 0x01, /* Version */ + 0x00, 0x00, 0x06, 0x66, /* Procedure */ + 0x00, 0x00, 0x00, 0x03, /* Type */ + 0x00, 0x00, 0x00, 0x99, /* Serial */ + 0x00, 0x00, 0x00, 0x02, /* Status */ + + 'T', 'h', 'e', ' ', + 'q', 'u', 'i', 'c', + 'k', ' ', 'b', 'r', + 'o', 'w', 'n', ' ', + 'f', 'o', 'x', ' ', + 'j', 'u', 'm', 'p', + 's', ' ', 'o', 'v', + 'e', 'r', ' ', 't', + 'h', 'e', ' ', 'l', + 'a', 'z', 'y', ' ', + 'd', 'o', 'g', + }; + memset(&msg, 0, sizeof(msg)); + + msg.header.prog = 0x11223344; + msg.header.vers = 0x01; + msg.header.proc = 0x666; + msg.header.type = VIR_NET_STREAM; + msg.header.serial = 0x99; + msg.header.status = VIR_NET_CONTINUE; + + if (virNetMessageEncodeHeader(&msg) < 0) + return -1; + + if (virNetMessageEncodePayloadRaw(&msg, stream, strlen(stream)) < 0) + return -1; + + if (ARRAY_CARDINALITY(expect) != msg.bufferLength) { + VIR_DEBUG("Expect message length %zu got %zu", + sizeof(expect), msg.bufferLength); + return -1; + } + + if (msg.bufferOffset != 0) { + VIR_DEBUG("Expect message offset 0 got %zu", + msg.bufferOffset); + return -1; + } + + if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) { + virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect)); + return -1; + } + + return 0; +} + + +static int +mymain(void) +{ + int ret = 0; + + signal(SIGPIPE, SIG_IGN); + + if 
(virtTestRun("Message Header Encode", 1, testMessageHeaderEncode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Header Decode", 1, testMessageHeaderDecode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Payload Encode", 1, testMessagePayloadEncode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Payload Decode", 1, testMessagePayloadDecode, NULL) < 0) + ret = -1; + + if (virtTestRun("Message Payload Stream Encode", 1, testMessagePayloadStreamEncode, NULL) < 0) + ret = -1; + + return (ret==0 ? EXIT_SUCCESS : EXIT_FAILURE); +} + +VIRT_TEST_MAIN(mymain) -- 1.7.4.4

Introduces a simple wrapper around the raw POSIX sockets APIs and name resolution APIs. Allows for easy creation of client and server sockets with correct usage of name resolution APIs for protocol-agnostic socket setup. It can listen on UNIX and TCP stream sockets. It can connect to UNIX and TCP stream sockets directly, or indirectly to UNIX sockets via an SSH tunnel or an external command. * src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Generic sockets APIs * tests/Makefile.am: Add socket test * tests/virnetsockettest.c: New test case * tests/testutils.c: Avoid overriding LIBVIRT_DEBUG settings * tests/ssh.c: Dumb helper program for SSH tunnelling tests --- cfg.mk | 1 + configure.ac | 2 +- po/POTFILES.in | 1 + src/Makefile.am | 3 +- src/rpc/virnetsocket.c | 836 ++++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetsocket.h | 107 ++++++ tests/.gitignore | 2 + tests/Makefile.am | 12 +- tests/ssh.c | 54 +++ tests/testutils.c | 8 +- tests/virnetsockettest.c | 522 +++++++++++++++++++++++++++++ 11 files changed, 1542 insertions(+), 6 deletions(-) create mode 100644 src/rpc/virnetsocket.c create mode 100644 src/rpc/virnetsocket.h create mode 100644 tests/ssh.c create mode 100644 tests/virnetsockettest.c diff --git a/cfg.mk b/cfg.mk index cf30929..c7d6b51 100644 --- a/cfg.mk +++ b/cfg.mk @@ -126,6 +126,7 @@ useless_free_options = \ --name=virJSONValueFree \ --name=virLastErrFreeData \ --name=virNetMessageFree \ + --name=virNetSocketFree \ --name=virNWFilterDefFree \ --name=virNWFilterEntryFree \ --name=virNWFilterHashTableFree \ diff --git a/configure.ac b/configure.ac index e17e7af..b46faf7 100644 --- a/configure.ac +++ b/configure.ac @@ -135,7 +135,7 @@ LIBS=$old_libs dnl Availability of various common headers (non-fatal if missing). 
AC_CHECK_HEADERS([pwd.h paths.h regex.h sys/un.h \ sys/poll.h syslog.h mntent.h net/ethernet.h linux/magic.h \ - sys/un.h sys/syscall.h]) + sys/un.h sys/syscall.h netinet/tcp.h ifaddrs.h]) AC_CHECK_LIB([intl],[gettext],[]) diff --git a/po/POTFILES.in b/po/POTFILES.in index a6048ec..adb2bbd 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -68,6 +68,7 @@ src/qemu/qemu_process.c src/remote/remote_client_bodies.h src/remote/remote_driver.c src/rpc/virnetmessage.c +src/rpc/virnetsocket.c src/secret/secret_driver.c src/security/security_apparmor.c src/security/security_dac.c diff --git a/src/Makefile.am b/src/Makefile.am index b24f319..639a41e 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1158,7 +1158,8 @@ noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ - rpc/virnetprotocol.h rpc/virnetprotocol.c + rpc/virnetprotocol.h rpc/virnetprotocol.c \ + rpc/virnetsocket.h rpc/virnetsocket.c libvirt_net_rpc_la_CFLAGS = \ $(AM_CFLAGS) libvirt_net_rpc_la_LDFLAGS = \ diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c new file mode 100644 index 0000000..f7d1095 --- /dev/null +++ b/src/rpc/virnetsocket.c @@ -0,0 +1,836 @@ +/* + * virnetsocket.c: generic network socket handling + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <sys/stat.h> +#include <sys/socket.h> +#include <unistd.h> +#include <sys/wait.h> + +#ifdef HAVE_NETINET_TCP_H +# include <netinet/tcp.h> +#endif + +#include "virnetsocket.h" +#include "util.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" +#include "files.h" +#include "event.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + + +struct _virNetSocket { + int fd; + int watch; + pid_t pid; + int errfd; + bool client; + virNetSocketIOFunc func; + void *opaque; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + char *localAddrStr; + char *remoteAddrStr; +}; + + +#ifndef WIN32 +static int virNetSocketForkDaemon(const char *binary) +{ + int ret; + virCommandPtr cmd = virCommandNewArgList(binary, + "--timeout=30", + NULL); + + virCommandAddEnvPassCommon(cmd); + virCommandClearCaps(cmd); + virCommandDaemonize(cmd); + ret = virCommandRun(cmd, NULL); + virCommandFree(cmd); + return ret; +} +#endif + + +static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr, + virSocketAddrPtr remoteAddr, + bool isClient, + int fd, int errfd, pid_t pid) +{ + virNetSocketPtr sock; + int no_slow_start = 1; + + VIR_DEBUG("localAddr=%p remoteAddr=%p fd=%d errfd=%d pid=%d", + localAddr, remoteAddr, + fd, errfd, pid); + + if (virSetCloseExec(fd) < 0) { + virReportSystemError(errno, "%s", + _("Unable to set close-on-exec flag")); + return NULL; + } + if (virSetNonBlock(fd) < 0) { + virReportSystemError(errno, "%s", + _("Unable to enable non-blocking flag")); + return NULL; + } + + if (VIR_ALLOC(sock) < 0) { + 
virReportOOMError(); + return NULL; + } + + if (localAddr) + sock->localAddr = *localAddr; + if (remoteAddr) + sock->remoteAddr = *remoteAddr; + sock->fd = fd; + sock->errfd = errfd; + sock->pid = pid; + + /* Disable nagle for TCP sockets */ + if (sock->localAddr.data.sa.sa_family == AF_INET || + sock->localAddr.data.sa.sa_family == AF_INET6) { + if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, + &no_slow_start, + sizeof(no_slow_start)) < 0) { + virReportSystemError(errno, "%s", + _("Unable to disable nagle algorithm")); + goto error; + } + } + + + if (localAddr && + !(sock->localAddrStr = virSocketFormatAddrFull(localAddr, true, ";"))) + goto error; + + if (remoteAddr && + !(sock->remoteAddrStr = virSocketFormatAddrFull(remoteAddr, true, ";"))) + goto error; + + sock->client = isClient; + + VIR_DEBUG("sock=%p localAddrStr=%s remoteAddrStr=%s", + sock, NULLSTR(sock->localAddrStr), NULLSTR(sock->remoteAddrStr)); + + return sock; + +error: + sock->fd = sock->errfd = -1; /* Caller owns fd/errfd on failure */ + virNetSocketFree(sock); + return NULL; +} + + +int virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **retsocks, + size_t *nretsocks) +{ + virNetSocketPtr *socks = NULL; + size_t nsocks = 0; + struct addrinfo *ai = NULL; + struct addrinfo hints; + int fd = -1; + int i; + + *retsocks = NULL; + *nretsocks = 0; + + memset(&hints, 0, sizeof hints); + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + hints.ai_socktype = SOCK_STREAM; + + int e = getaddrinfo(nodename, service, &hints, &ai); + if (e != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to resolve address '%s' service '%s': %s"), + nodename, service, gai_strerror(e)); + return -1; + } + + struct addrinfo *runp = ai; + while (runp) { + virSocketAddr addr; + + memset(&addr, 0, sizeof(addr)); + + if ((fd = socket(runp->ai_family, runp->ai_socktype, + runp->ai_protocol)) < 0) { + virReportSystemError(errno, "%s", _("Unable to create socket")); + goto error; + } + + int opt = 
1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt) < 0) { + virReportSystemError(errno, "%s", _("Unable to enable port reuse")); + goto error; + } + +#ifdef IPV6_V6ONLY + if (runp->ai_family == PF_INET6) { + int on = 1; + /* + * Normally on Linux an INET6 socket will bind to the INET4 + * address too. If getaddrinfo returns results with INET4 + * first though, this will result in INET6 binding failing. + * We can trivially cope with multiple server sockets, so + * we force it to only listen on IPv6 + */ + if (setsockopt(fd, IPPROTO_IPV6,IPV6_V6ONLY, + (void*)&on, sizeof on) < 0) { + virReportSystemError(errno, "%s", + _("Unable to force bind to IPv6 only")); + goto error; + } + } +#endif + + if (bind(fd, runp->ai_addr, runp->ai_addrlen) < 0) { + if (errno != EADDRINUSE) { + virReportSystemError(errno, "%s", _("Unable to bind to port")); + goto error; + } + VIR_FORCE_CLOSE(fd); + continue; + } + + addr.len = sizeof(addr.data); + if (getsockname(fd, &addr.data.sa, &addr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + + VIR_DEBUG("%p f=%d f=%d", &addr, runp->ai_family, addr.data.sa.sa_family); + + if (VIR_EXPAND_N(socks, nsocks, 1) < 0) { + virReportOOMError(); + goto error; + } + + if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, false, fd, -1, 0))) + goto error; + runp = runp->ai_next; + fd = -1; + } + + freeaddrinfo(ai); + + *retsocks = socks; + *nretsocks = nsocks; + return 0; + +error: + for (i = 0 ; i < nsocks ; i++) + virNetSocketFree(socks[i]); + VIR_FREE(socks); + freeaddrinfo(ai); + VIR_FORCE_CLOSE(fd); + return -1; +} + + +#if HAVE_SYS_UN_H +int virNetSocketNewListenUNIX(const char *path, + mode_t mask, + gid_t grp, + virNetSocketPtr *retsock) +{ + virSocketAddr addr; + mode_t oldmask; + int fd; + + *retsock = NULL; + + memset(&addr, 0, sizeof(addr)); + + addr.len = sizeof(addr.data.un); + + if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) { + virReportSystemError(errno, "%s", 
_("Failed to create socket")); + goto error; + } + + addr.data.un.sun_family = AF_UNIX; + if (virStrcpyStatic(addr.data.un.sun_path, path) == NULL) { + virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path); + goto error; + } + if (addr.data.un.sun_path[0] == '@') + addr.data.un.sun_path[0] = '\0'; + else + unlink(addr.data.un.sun_path); + + oldmask = umask(~mask); + + if (bind(fd, &addr.data.sa, addr.len) < 0) { + umask(oldmask); + virReportSystemError(errno, + _("Failed to bind socket to '%s'"), + path); + goto error; + } + umask(oldmask); + + /* chown() doesn't work for abstract sockets but we use them only + * if libvirtd runs unprivileged + */ + if (grp != 0 && chown(path, -1, grp)) { + virReportSystemError(errno, + _("Failed to change group ID of '%s' to %d"), + path, grp); + goto error; + } + + if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0))) + goto error; + + return 0; + +error: + if (path[0] != '@') + unlink(path); + VIR_FORCE_CLOSE(fd); + return -1; +} +#else +int virNetSocketNewListenUNIX(const char *path ATTRIBUTE_UNUSED, + mode_t mask ATTRIBUTE_UNUSED, + gid_t grp ATTRIBUTE_UNUSED, + virNetSocketPtr *retsock ATTRIBUTE_UNUSED) +{ + virReportSystemError(ENOSYS, "%s", + _("UNIX sockets are not supported on this platform")); + return -1; +} +#endif + + +int virNetSocketNewConnectTCP(const char *nodename, + const char *service, + virNetSocketPtr *retsock) +{ + struct addrinfo *ai = NULL; + struct addrinfo hints; + int fd = -1; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + struct addrinfo *runp; + int savedErrno = ENOENT; + + *retsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + memset(&hints, 0, sizeof hints); + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + hints.ai_socktype = SOCK_STREAM; + + int e = getaddrinfo(nodename, service, &hints, &ai); + if (e != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to resolve address '%s' service '%s': %s"), 
+ nodename, service, gai_strerror (e)); + return -1; + } + + runp = ai; + while (runp) { + int opt = 1; + + if ((fd = socket(runp->ai_family, runp->ai_socktype, + runp->ai_protocol)) < 0) { + virReportSystemError(errno, "%s", _("Unable to create socket")); + goto error; + } + + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt); + + if (connect(fd, runp->ai_addr, runp->ai_addrlen) >= 0) + break; + + savedErrno = errno; + VIR_FORCE_CLOSE(fd); + runp = runp->ai_next; + } + + if (fd == -1) { + virReportSystemError(savedErrno, + _("unable to connect to server at '%s:%s'"), + nodename, service); + goto error; + } + + localAddr.len = sizeof(localAddr.data); + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + + remoteAddr.len = sizeof(remoteAddr.data); + if (getpeername(fd, &remoteAddr.data.sa, &remoteAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get remote socket name")); + goto error; + } + + if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0))) + goto error; + + freeaddrinfo(ai); + + return 0; + +error: + freeaddrinfo(ai); + VIR_FORCE_CLOSE(fd); + return -1; +} + + +#if HAVE_SYS_UN_H +int virNetSocketNewConnectUNIX(const char *path, + bool spawnDaemon, + const char *binary, + virNetSocketPtr *retsock) +{ + virSocketAddr localAddr; + virSocketAddr remoteAddr; + int fd; + int retries = 0; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.un); + + if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) { + virReportSystemError(errno, "%s", _("Failed to create socket")); + goto error; + } + + remoteAddr.data.un.sun_family = AF_UNIX; + if (virStrcpyStatic(remoteAddr.data.un.sun_path, path) == NULL) { + virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path); + goto error; + } + if (remoteAddr.data.un.sun_path[0] == '@') + 
remoteAddr.data.un.sun_path[0] = '\0'; + +retry: + if (connect(fd, &remoteAddr.data.sa, remoteAddr.len) < 0) { + if (errno == ECONNREFUSED && spawnDaemon && retries < 20) { + if (retries == 0 && + virNetSocketForkDaemon(binary) < 0) + goto error; + + retries++; + usleep(1000 * 100 * retries); + goto retry; + } + + virReportSystemError(errno, + _("Failed to connect socket to '%s'"), + path); + goto error; + } + + localAddr.len = sizeof(localAddr.data); + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + goto error; + } + + if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0))) + goto error; + + return 0; + +error: + VIR_FORCE_CLOSE(fd); + return -1; +} +#else +int virNetSocketNewConnectUNIX(const char *path ATTRIBUTE_UNUSED, + bool spawnDaemon ATTRIBUTE_UNUSED, + const char *binary ATTRIBUTE_UNUSED, + virNetSocketPtr *retsock ATTRIBUTE_UNUSED) +{ + virReportSystemError(ENOSYS, "%s", + _("UNIX sockets are not supported on this platform")); + return -1; +} +#endif + + +#ifndef WIN32 +int virNetSocketNewConnectCommand(virCommandPtr cmd, + virNetSocketPtr *retsock) +{ + pid_t pid = 0; + int sv[2]; + int errfd[2]; + + *retsock = NULL; + + /* Fork off the external process. Use socketpair to create a private + * (unnamed) Unix domain socket to the child process so we don't have + * to faff around with two file descriptors (a la 'pipe(2)'). + */ + if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + if (pipe(errfd) < 0) { + virReportSystemError(errno, "%s", + _("unable to create socket pair")); + goto error; + } + + virCommandSetInputFD(cmd, sv[1]); + virCommandSetOutputFD(cmd, &sv[1]); + virCommandSetErrorFD(cmd, &errfd[1]); + + if (virCommandRunAsync(cmd, &pid) < 0) + goto error; + + /* Parent continues here. 
*/ + VIR_FORCE_CLOSE(sv[1]); + VIR_FORCE_CLOSE(errfd[1]); + + if (!(*retsock = virNetSocketNew(NULL, NULL, true, sv[0], errfd[0], pid))) + goto error; + + virCommandFree(cmd); + + return 0; + +error: + VIR_FORCE_CLOSE(sv[0]); + VIR_FORCE_CLOSE(sv[1]); + VIR_FORCE_CLOSE(errfd[0]); + VIR_FORCE_CLOSE(errfd[1]); + + if (pid > 0) { + kill(pid, SIGTERM); + if (virCommandWait(cmd, NULL) < 0) { + kill(pid, SIGKILL); + if (virCommandWait(cmd, NULL) < 0) { + VIR_WARN("Unable to wait for command %d", pid); + } + } + } + + virCommandFree(cmd); + + return -1; +} +#else +int virNetSocketNewConnectCommand(virCommandPtr cmd ATTRIBUTE_UNUSED, + virNetSocketPtr *retsock ATTRIBUTE_UNUSED) +{ + virReportSystemError(errno, "%s", + _("Tunnelling sockets not supported on this platform")); + return -1; +} +#endif + +int virNetSocketNewConnectSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path, + virNetSocketPtr *retsock) +{ + virCommandPtr cmd; + *retsock = NULL; + + cmd = virCommandNew(binary ? binary : "ssh"); + virCommandAddEnvPassCommon(cmd); + virCommandAddEnvPass(cmd, "SSH_AUTH_SOCK"); + virCommandAddEnvPass(cmd, "SSH_ASKPASS"); + virCommandClearCaps(cmd); + + if (service) + virCommandAddArgList(cmd, "-p", service, NULL); + if (username) + virCommandAddArgList(cmd, "-l", username, NULL); + if (noTTY) + virCommandAddArgList(cmd, "-T", "-o", "BatchMode=yes", + "-e", "none", NULL); + virCommandAddArgList(cmd, nodename, + netcat ? 
netcat : "nc", + "-U", path, NULL); + + return virNetSocketNewConnectCommand(cmd, retsock); +} + + +int virNetSocketNewConnectExternal(const char **cmdargv, + virNetSocketPtr *retsock) +{ + virCommandPtr cmd; + + *retsock = NULL; + + cmd = virCommandNewArgs(cmdargv); + virCommandAddEnvPassCommon(cmd); + virCommandClearCaps(cmd); + + return virNetSocketNewConnectCommand(cmd, retsock); +} + + +void virNetSocketFree(virNetSocketPtr sock) +{ + if (!sock) + return; + + VIR_DEBUG("sock=%p fd=%d", sock, sock->fd); + if (sock->watch > 0) { + virEventRemoveHandle(sock->watch); + sock->watch = -1; + } + +#ifdef HAVE_SYS_UN_H + /* If a server socket, then unlink UNIX path */ + if (!sock->client && + sock->localAddr.data.sa.sa_family == AF_UNIX && + sock->localAddr.data.un.sun_path[0] != '\0') + unlink(sock->localAddr.data.un.sun_path); +#endif + + VIR_FORCE_CLOSE(sock->fd); + VIR_FORCE_CLOSE(sock->errfd); + +#ifndef WIN32 + if (sock->pid > 0) { + pid_t reap; + kill(sock->pid, SIGTERM); + do { +retry: + reap = waitpid(sock->pid, NULL, 0); + if (reap == -1 && errno == EINTR) + goto retry; + } while (reap != -1 && reap != sock->pid); + } +#endif + + VIR_FREE(sock->localAddrStr); + VIR_FREE(sock->remoteAddrStr); + + VIR_FREE(sock); +} + + +int virNetSocketGetFD(virNetSocketPtr sock) +{ + return sock->fd; +} + + +bool virNetSocketIsLocal(virNetSocketPtr sock) +{ + if (sock->localAddr.data.sa.sa_family == AF_UNIX) + return true; + return false; +} + + +#ifdef SO_PEERCRED +int virNetSocketGetLocalIdentity(virNetSocketPtr sock, + uid_t *uid, + pid_t *pid) +{ + struct ucred cr; + unsigned int cr_len = sizeof (cr); + + if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) { + virReportSystemError(errno, "%s", + _("Failed to get client socket identity")); + return -1; + } + + *pid = cr.pid; + *uid = cr.uid; + return 0; +} +# else +int virNetSocketGetLocalIdentity(virNetSocketPtr sock ATTRIBUTE_UNUSED, + uid_t *uid ATTRIBUTE_UNUSED, + pid_t *pid ATTRIBUTE_UNUSED) +{ + /* 
XXX Many more OS support UNIX socket credentials we could port to. See dbus ....*/ + virReportSystemError(ENOSYS, "%s", + _("Client socket identity not available")); + return -1; +} +# endif + + +int virNetSocketSetBlocking(virNetSocketPtr sock, + bool blocking) +{ + return virSetBlocking(sock->fd, blocking); +} + + +const char *virNetSocketLocalAddrString(virNetSocketPtr sock) +{ + return sock->localAddrStr; +} + +const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) +{ + return sock->remoteAddrStr; +} + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) +{ + return read(sock->fd, buf, len); +} + +ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) +{ + return write(sock->fd, buf, len); +} + + +int virNetSocketListen(virNetSocketPtr sock) +{ + if (listen(sock->fd, 30) < 0) { + virReportSystemError(errno, "%s", _("Unable to listen on socket")); + return -1; + } + return 0; +} + +int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock) +{ + int fd; + virSocketAddr localAddr; + virSocketAddr remoteAddr; + + *clientsock = NULL; + + memset(&localAddr, 0, sizeof(localAddr)); + memset(&remoteAddr, 0, sizeof(remoteAddr)); + + remoteAddr.len = sizeof(remoteAddr.data.stor); + if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) { + if (errno == ECONNABORTED || + errno == EAGAIN) + return 0; + + virReportSystemError(errno, "%s", + _("Unable to accept client")); + return -1; + } + + localAddr.len = sizeof(localAddr.data); + if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) { + virReportSystemError(errno, "%s", _("Unable to get local socket name")); + VIR_FORCE_CLOSE(fd); + return -1; + } + + if (!(*clientsock = virNetSocketNew(&localAddr, + &remoteAddr, + true, + fd, -1, 0))) { + VIR_FORCE_CLOSE(fd); + return -1; + } + + return 0; +} + + +static void virNetSocketEventHandle(int fd ATTRIBUTE_UNUSED, + int watch ATTRIBUTE_UNUSED, + int events, + void *opaque) +{ + virNetSocketPtr 
sock = opaque; + + sock->func(sock, events, sock->opaque); +} + +int virNetSocketAddIOCallback(virNetSocketPtr sock, + int events, + virNetSocketIOFunc func, + void *opaque) +{ + if (sock->watch > 0) { + VIR_DEBUG("Watch already registered on socket %p", sock); + return -1; + } + + if ((sock->watch = virEventAddHandle(sock->fd, + events, + virNetSocketEventHandle, + sock, + NULL)) < 0) { + VIR_WARN("Failed to register watch on socket %p", sock); + return -1; + } + sock->func = func; + sock->opaque = opaque; + + return 0; +} + +void virNetSocketUpdateIOCallback(virNetSocketPtr sock, + int events) +{ + if (sock->watch <= 0) { + VIR_DEBUG("Watch not registered on socket %p", sock); + return; + } + + virEventUpdateHandle(sock->watch, events); +} + +void virNetSocketRemoveIOCallback(virNetSocketPtr sock) +{ + if (sock->watch <= 0) { + VIR_DEBUG("Watch not registered on socket %p", sock); + return; + } + + virEventRemoveHandle(sock->watch); + sock->watch = 0; +} diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h new file mode 100644 index 0000000..218fe8f --- /dev/null +++ b/src/rpc/virnetsocket.h @@ -0,0 +1,107 @@ +/* + * virnetsocket.h: generic network socket handling + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_SOCKET_H__ +# define __VIR_NET_SOCKET_H__ + +# include "network.h" +# include "command.h" + +typedef struct _virNetSocket virNetSocket; +typedef virNetSocket *virNetSocketPtr; + + +typedef void (*virNetSocketIOFunc)(virNetSocketPtr sock, + int events, + void *opaque); + + +int virNetSocketNewListenTCP(const char *nodename, + const char *service, + virNetSocketPtr **addrs, + size_t *naddrs); + +int virNetSocketNewListenUNIX(const char *path, + mode_t mask, + gid_t grp, + virNetSocketPtr *addr); + +int virNetSocketNewConnectTCP(const char *nodename, + const char *service, + virNetSocketPtr *addr); + +int virNetSocketNewConnectUNIX(const char *path, + bool spawnDaemon, + const char *binary, + virNetSocketPtr *addr); + +int virNetSocketNewConnectCommand(virCommandPtr cmd, + virNetSocketPtr *retsock); + +int virNetSocketNewConnectSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path, + virNetSocketPtr *addr); + +int virNetSocketNewConnectExternal(const char **cmdargv, + virNetSocketPtr *addr); + +int virNetSocketGetFD(virNetSocketPtr sock); +bool virNetSocketIsLocal(virNetSocketPtr sock); + +int virNetSocketGetLocalIdentity(virNetSocketPtr sock, + uid_t *uid, + pid_t *pid); + +int virNetSocketSetBlocking(virNetSocketPtr sock, + bool blocking); + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); +ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); + +void virNetSocketFree(virNetSocketPtr sock); + +const char *virNetSocketLocalAddrString(virNetSocketPtr sock); +const char *virNetSocketRemoteAddrString(virNetSocketPtr sock); + +int 
virNetSocketListen(virNetSocketPtr sock); +int virNetSocketAccept(virNetSocketPtr sock, + virNetSocketPtr *clientsock); + +int virNetSocketAddIOCallback(virNetSocketPtr sock, + int events, + virNetSocketIOFunc func, + void *opaque); + +void virNetSocketUpdateIOCallback(virNetSocketPtr sock, + int events); + +void virNetSocketRemoveIOCallback(virNetSocketPtr sock); + + + +#endif /* __VIR_NET_SOCKET_H__ */ diff --git a/tests/.gitignore b/tests/.gitignore index 36115ea..7f26dd7 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1,6 +1,7 @@ *.exe .deps .libs +ssh commandhelper commandhelper.log commandhelper.pid @@ -31,6 +32,7 @@ storagepoolxml2xmltest storagevolxml2xmltest virbuftest virnetmessagetest +virnetsockettest virshtest vmx2xmltest xencapstest diff --git a/tests/Makefile.am b/tests/Makefile.am index f80b98f..071fe6f 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -78,7 +78,11 @@ EXTRA_DIST = \ check_PROGRAMS = virshtest conftest sockettest \ nodeinfotest qparamtest virbuftest \ commandtest commandhelper seclabeltest \ - hashtest virnetmessagetest + hashtest virnetmessagetest virnetsockettest ssh + +# This is a fake SSH we use from virnetsockettest +ssh_SOURCES = ssh.c +ssh_LDADD = $(COVERAGE_LDFLAGS) if WITH_XEN check_PROGRAMS += xml2sexprtest sexpr2xmltest \ @@ -181,6 +185,7 @@ TESTS = virshtest \ seclabeltest \ hashtest \ virnetmessagetest \ + virnetsockettest \ $(test_scripts) if WITH_XEN @@ -388,6 +393,11 @@ virnetmessagetest_SOURCES = \ virnetmessagetest_CFLAGS = -Dabs_builddir="\"$(abs_builddir)\"" virnetmessagetest_LDADD = $(LDADDS) +virnetsockettest_SOURCES = \ + virnetsockettest.c testutils.h testutils.c +virnetsockettest_CFLAGS = -Dabs_builddir="\"$(abs_builddir)\"" +virnetsockettest_LDADD = $(LDADDS) + seclabeltest_SOURCES = \ seclabeltest.c diff --git a/tests/ssh.c b/tests/ssh.c new file mode 100644 index 0000000..08bb63d --- /dev/null +++ b/tests/ssh.c @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2011 Red Hat, Inc. 
+ * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <stdio.h> +#include "internal.h" + +int main(int argc, char **argv) +{ + int i; + int failConnect = 0; /* Exit -1, with no data on stdout, msg on stderr */ + int dieEarly = 0; /* Exit -1, with partial data on stdout, msg on stderr */ + + for (i = 1 ; i < argc ; i++) { + if (STREQ(argv[i], "nosuchhost")) + failConnect = 1; + else if (STREQ(argv[i], "crashinghost")) + dieEarly = 1; + } + + if (failConnect) { + fprintf(stderr, "%s", "Cannot connect to host nosuchhost\n"); + return -1; + } + + if (dieEarly) { + printf("%s\n", "Hello World"); + fprintf(stderr, "%s", "Hangup from host\n"); + return -1; + } + + for (i = 1 ; i < argc ; i++) + printf("%s%c", argv[i], i == (argc -1) ? 
'\n' : ' '); + + return 0; +} diff --git a/tests/testutils.c b/tests/testutils.c index d87347d..64d4ec0 100644 --- a/tests/testutils.c +++ b/tests/testutils.c @@ -569,9 +569,11 @@ int virtTestMain(int argc, return 1; virLogSetFromEnv(); - if (virLogDefineOutput(virtTestLogOutput, virtTestLogClose, &testLog, - 0, 0, NULL, 0) < 0) - return 1; + if (!getenv("LIBVIRT_DEBUG") && !virLogGetNbOutputs()) { + if (virLogDefineOutput(virtTestLogOutput, virtTestLogClose, &testLog, + 0, 0, NULL, 0) < 0) + return 1; + } #if TEST_OOM if ((oomStr = getenv("VIR_TEST_OOM")) != NULL) { diff --git a/tests/virnetsockettest.c b/tests/virnetsockettest.c new file mode 100644 index 0000000..65dd5d0 --- /dev/null +++ b/tests/virnetsockettest.c @@ -0,0 +1,522 @@ +/* + * Copyright (C) 2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <stdlib.h> +#include <signal.h> +#ifdef HAVE_IFADDRS_H +#include <ifaddrs.h> +#endif + +#include "testutils.h" +#include "util.h" +#include "virterror_internal.h" +#include "memory.h" +#include "logging.h" +#include "ignore-value.h" +#include "files.h" + +#include "rpc/virnetsocket.h" + +#define VIR_FROM_THIS VIR_FROM_RPC + +#if HAVE_IFADDRS_H +#define BASE_PORT 5672 + +static int +checkProtocols(bool *hasIPv4, bool *hasIPv6, + int *freePort) +{ + struct ifaddrs *ifaddr = NULL, *ifa; + struct sockaddr_in in4; + struct sockaddr_in6 in6; + int s4 = -1, s6 = -1; + int i; + int ret = -1; + + *hasIPv4 = *hasIPv6 = false; + *freePort = 0; + + if (getifaddrs(&ifaddr) < 0) + goto cleanup; + + for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) { + if (!ifa->ifa_addr) + continue; + + if (ifa->ifa_addr->sa_family == AF_INET) + *hasIPv4 = true; + if (ifa->ifa_addr->sa_family == AF_INET6) + *hasIPv6 = true; + } + + VIR_DEBUG("Protocols: v4 %d v6 %d\n", *hasIPv4, *hasIPv6); + + freeifaddrs(ifaddr); + + for (i = 0 ; i < 50 ; i++) { + int only = 1; + if ((s4 = socket(AF_INET, SOCK_STREAM, 0)) < 0) + goto cleanup; + + if ((s6 = socket(AF_INET6, SOCK_STREAM, 0)) < 0) + goto cleanup; + + if (setsockopt(s6, IPPROTO_IPV6, IPV6_V6ONLY, &only, sizeof(only)) < 0) + goto cleanup; + + memset(&in4, 0, sizeof(in4)); + memset(&in6, 0, sizeof(in6)); + + in4.sin_family = AF_INET; + in4.sin_port = htons(BASE_PORT + i); + in4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + in6.sin6_family = AF_INET6; + in6.sin6_port = htons(BASE_PORT + i); + in6.sin6_addr = in6addr_loopback; + + if (bind(s4, (struct sockaddr *)&in4, sizeof(in4)) < 0) { + if (errno == EADDRINUSE) { + VIR_FORCE_CLOSE(s4); + VIR_FORCE_CLOSE(s6); + continue; + } + goto cleanup; + } + if (bind(s6, (struct sockaddr *)&in6, sizeof(in6)) < 0) { + if (errno == EADDRINUSE) { + VIR_FORCE_CLOSE(s4); + VIR_FORCE_CLOSE(s6); + continue; + } + goto cleanup; + } + + *freePort = 
BASE_PORT + i; + break; + } + + VIR_DEBUG("Choose port %d\n", *freePort); + + ret = 0; + +cleanup: + VIR_FORCE_CLOSE(s4); + VIR_FORCE_CLOSE(s6); + return ret; +} + + +struct testTCPData { + const char *lnode; + int port; + const char *cnode; +}; + +static int testSocketTCPAccept(const void *opaque) +{ + virNetSocketPtr *lsock = NULL; /* Listen socket */ + size_t nlsock = 0, i; + virNetSocketPtr ssock = NULL; /* Server socket */ + virNetSocketPtr csock = NULL; /* Client socket */ + const struct testTCPData *data = opaque; + int ret = -1; + char portstr[100]; + + snprintf(portstr, sizeof(portstr), "%d", data->port); + + if (virNetSocketNewListenTCP(data->lnode, portstr, &lsock, &nlsock) < 0) + goto cleanup; + + for (i = 0 ; i < nlsock ; i++) { + if (virNetSocketListen(lsock[i]) < 0) + goto cleanup; + } + + if (virNetSocketNewConnectTCP(data->cnode, portstr, &csock) < 0) + goto cleanup; + + virNetSocketFree(csock); + + for (i = 0 ; i < nlsock ; i++) { + if (virNetSocketAccept(lsock[i], &ssock) != -1 && ssock) { + char c = 'a'; + if (virNetSocketWrite(ssock, &c, 1) != -1 && + virNetSocketRead(ssock, &c, 1) != -1) { + VIR_DEBUG("Unexpected client socket present"); + goto cleanup; + } + } + virNetSocketFree(ssock); + ssock = NULL; + } + + ret = 0; + +cleanup: + virNetSocketFree(ssock); + for (i = 0 ; i < nlsock ; i++) + virNetSocketFree(lsock[i]); + VIR_FREE(lsock); + return ret; +} +#endif + + +#ifndef WIN32 +static int testSocketUNIXAccept(const void *data ATTRIBUTE_UNUSED) +{ + virNetSocketPtr lsock = NULL; /* Listen socket */ + virNetSocketPtr ssock = NULL; /* Server socket */ + virNetSocketPtr csock = NULL; /* Client socket */ + int ret = -1; + + char *path; + if (progname[0] == '/') { + if (virAsprintf(&path, "%s-test.sock", progname) < 0) { + virReportOOMError(); + goto cleanup; + } + } else { + if (virAsprintf(&path, "%s/%s-test.sock", abs_builddir, progname) < 0) { + virReportOOMError(); + goto cleanup; + } + } + + if (virNetSocketNewListenUNIX(path, 0700, 
getgid(), &lsock) < 0) + goto cleanup; + + if (virNetSocketListen(lsock) < 0) + goto cleanup; + + if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0) + goto cleanup; + + virNetSocketFree(csock); + + if (virNetSocketAccept(lsock, &ssock) != -1) { + char c = 'a'; + if (virNetSocketWrite(ssock, &c, 1) != -1) { + VIR_DEBUG("Unexpected client socket present"); + goto cleanup; + } + } + + ret = 0; + +cleanup: + VIR_FREE(path); + virNetSocketFree(lsock); + virNetSocketFree(ssock); + return ret; +} + + +static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED) +{ + virNetSocketPtr lsock = NULL; /* Listen socket */ + virNetSocketPtr ssock = NULL; /* Server socket */ + virNetSocketPtr csock = NULL; /* Client socket */ + int ret = -1; + + char *path; + if (progname[0] == '/') { + if (virAsprintf(&path, "%s-test.sock", progname) < 0) { + virReportOOMError(); + goto cleanup; + } + } else { + if (virAsprintf(&path, "%s/%s-test.sock", abs_builddir, progname) < 0) { + virReportOOMError(); + goto cleanup; + } + } + + if (virNetSocketNewListenUNIX(path, 0700, getgid(), &lsock) < 0) + goto cleanup; + + if (STRNEQ(virNetSocketLocalAddrString(lsock), "127.0.0.1;0")) { + VIR_DEBUG("Unexpected local address"); + goto cleanup; + } + + if (virNetSocketRemoteAddrString(lsock) != NULL) { + VIR_DEBUG("Unexpected remote address"); + goto cleanup; + } + + if (virNetSocketListen(lsock) < 0) + goto cleanup; + + if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0) + goto cleanup; + + if (STRNEQ(virNetSocketLocalAddrString(csock), "127.0.0.1;0")) { + VIR_DEBUG("Unexpected local address"); + goto cleanup; + } + + if (STRNEQ(virNetSocketRemoteAddrString(csock), "127.0.0.1;0")) { + VIR_DEBUG("Unexpected local address"); + goto cleanup; + } + + + if (virNetSocketAccept(lsock, &ssock) < 0) { + VIR_DEBUG("Unexpected client socket missing"); + goto cleanup; + } + + + if (STRNEQ(virNetSocketLocalAddrString(ssock), "127.0.0.1;0")) { + VIR_DEBUG("Unexpected local address"); + 
goto cleanup; + } + + if (STRNEQ(virNetSocketRemoteAddrString(ssock), "127.0.0.1;0")) { + VIR_DEBUG("Unexpected local address"); + goto cleanup; + } + + + ret = 0; + +cleanup: + VIR_FREE(path); + virNetSocketFree(lsock); + virNetSocketFree(ssock); + virNetSocketFree(csock); + return ret; +} + +static int testSocketCommandNormal(const void *data ATTRIBUTE_UNUSED) +{ + virNetSocketPtr csock = NULL; /* Client socket */ + char buf[100]; + size_t i; + int ret = -1; + virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/zero", NULL); + virCommandAddEnvPassCommon(cmd); + + if (virNetSocketNewConnectCommand(cmd, &csock) < 0) + goto cleanup; + + virNetSocketSetBlocking(csock, true); + + if (virNetSocketRead(csock, buf, sizeof(buf)) < 0) + goto cleanup; + + for (i = 0 ; i < sizeof(buf) ; i++) + if (buf[i] != '\0') + goto cleanup; + + ret = 0; + +cleanup: + virNetSocketFree(csock); + return ret; +} + +static int testSocketCommandFail(const void *data ATTRIBUTE_UNUSED) +{ + virNetSocketPtr csock = NULL; /* Client socket */ + char buf[100]; + int ret = -1; + virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/does-not-exist", NULL); + virCommandAddEnvPassCommon(cmd); + + if (virNetSocketNewConnectCommand(cmd, &csock) < 0) + goto cleanup; + + virNetSocketSetBlocking(csock, true); + + if (virNetSocketRead(csock, buf, sizeof(buf)) == 0) + goto cleanup; + + ret = 0; + +cleanup: + virNetSocketFree(csock); + return ret; +} + +struct testSSHData { + const char *nodename; + const char *service; + const char *binary; + const char *username; + bool noTTY; + const char *netcat; + const char *path; + + const char *expectOut; + bool failConnect; + bool dieEarly; +}; + +static int testSocketSSH(const void *opaque) +{ + const struct testSSHData *data = opaque; + virNetSocketPtr csock = NULL; /* Client socket */ + int ret = -1; + char buf[1024]; + + if (virNetSocketNewConnectSSH(data->nodename, + data->service, + data->binary, + data->username, + data->noTTY, + data->netcat, + 
data->path, + &csock) < 0) + goto cleanup; + + virNetSocketSetBlocking(csock, true); + + if (data->failConnect) { + if (virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) { + VIR_DEBUG("Expected connect failure, but got some socket data"); + goto cleanup; + } + } else { + ssize_t rv; + if ((rv = virNetSocketRead(csock, buf, sizeof(buf)-1)) < 0) { + VIR_DEBUG("Didn't get any socket data"); + goto cleanup; + } + buf[rv] = '\0'; + + if (!STREQ(buf, data->expectOut)) { + virtTestDifference(stderr, data->expectOut, buf); + goto cleanup; + } + + if (data->dieEarly && + virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) { + VIR_DEBUG("Got too much socket data"); + goto cleanup; + } + } + + ret = 0; + +cleanup: + virNetSocketFree(csock); + return ret; +} + +#endif + + +static int +mymain(void) +{ + int ret = 0; +#ifdef HAVE_IFADDRS_H + bool hasIPv4, hasIPv6; + int freePort; +#endif + + signal(SIGPIPE, SIG_IGN); + +#ifdef HAVE_IFADDRS_H + if (checkProtocols(&hasIPv4, &hasIPv6, &freePort) < 0) { + fprintf(stderr, "Cannot identify IPv4/6 availability\n"); + return (EXIT_FAILURE); + } + + if (hasIPv4) { + struct testTCPData tcpData = { "127.0.0.1", freePort, "127.0.0.1" }; + if (virtTestRun("Socket TCP/IPv4 Accept", 1, testSocketTCPAccept, &tcpData) < 0) + ret = -1; + } + if (hasIPv6) { + struct testTCPData tcpData = { "::1", freePort, "::1" }; + if (virtTestRun("Socket TCP/IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0) + ret = -1; + } + if (hasIPv6 && hasIPv4) { + struct testTCPData tcpData = { NULL, freePort, "127.0.0.1" }; + if (virtTestRun("Socket TCP/IPv4+IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0) + ret = -1; + + tcpData.cnode = "::1"; + if (virtTestRun("Socket TCP/IPv4+IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0) + ret = -1; + } +#endif + +#ifndef WIN32 + if (virtTestRun("Socket UNIX Accept", 1, testSocketUNIXAccept, NULL) < 0) + ret = -1; + + if (virtTestRun("Socket UNIX Addrs", 1, testSocketUNIXAddrs, NULL) < 0) + ret = -1; + + if 
(virtTestRun("Socket External Command /dev/zero", 1, testSocketCommandNormal, NULL) < 0) + ret = -1; + if (virtTestRun("Socket External Command /dev/does-not-exist", 1, testSocketCommandFail, NULL) < 0) + ret = -1; + + struct testSSHData sshData1 = { + .nodename = "somehost", + .path = "/tmp/socket", + .expectOut = "somehost nc -U /tmp/socket\n", + }; + if (virtTestRun("SSH test 1", 1, testSocketSSH, &sshData1) < 0) + ret = -1; + + struct testSSHData sshData2 = { + .nodename = "somehost", + .service = "9000", + .username = "fred", + .netcat = "netcat", + .noTTY = true, + .path = "/tmp/socket", + .expectOut = "-p 9000 -l fred -T -o BatchMode=yes -e none somehost netcat -U /tmp/socket\n", + }; + if (virtTestRun("SSH test 2", 1, testSocketSSH, &sshData2) < 0) + ret = -1; + + struct testSSHData sshData3 = { + .nodename = "nosuchhost", + .path = "/tmp/socket", + .failConnect = true, + }; + if (virtTestRun("SSH test 3", 1, testSocketSSH, &sshData3) < 0) + ret = -1; + + struct testSSHData sshData4 = { + .nodename = "crashyhost", + .path = "/tmp/socket", + .expectOut = "crashyhost nc -U /tmp/socket\n", + .dieEarly = true, + }; + if (virtTestRun("SSH test 4", 1, testSocketSSH, &sshData4) < 0) + ret = -1; + +#endif + + return (ret==0 ? EXIT_SUCCESS : EXIT_FAILURE); +} + +VIRT_TEST_MAIN(mymain) -- 1.7.4.4

This provides two modules for handling TLS * virNetTLSContext provides the process-wide state, in particular all the x509 credentials, DH params and x509 whitelists * virNetTLSSession provides the per-connection state, i.e. the TLS session itself. The virNetTLSContext provides APIs for validating a TLS session's x509 credentials. The virNetTLSSession includes APIs for performing the initial TLS handshake and sending/receiving encrypted data * src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnettlscontext.c, src/rpc/virnettlscontext.h: Generic TLS handling code --- cfg.mk | 1 + po/POTFILES.in | 1 + src/Makefile.am | 5 +- src/rpc/virnettlscontext.c | 894 ++++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnettlscontext.h | 99 +++++ 5 files changed, 999 insertions(+), 1 deletions(-) create mode 100644 src/rpc/virnettlscontext.c create mode 100644 src/rpc/virnettlscontext.h diff --git a/cfg.mk b/cfg.mk index c7d6b51..2011968 100644 --- a/cfg.mk +++ b/cfg.mk @@ -127,6 +127,7 @@ useless_free_options = \ --name=virLastErrFreeData \ --name=virNetMessageFree \ --name=virNetSocketFree \ + --name=virNetTLSSessionFree \ --name=virNWFilterDefFree \ --name=virNWFilterEntryFree \ --name=virNWFilterHashTableFree \ diff --git a/po/POTFILES.in b/po/POTFILES.in index adb2bbd..55cb816 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -69,6 +69,7 @@ src/remote/remote_client_bodies.h src/remote/remote_driver.c src/rpc/virnetmessage.c src/rpc/virnetsocket.c +src/rpc/virnettlscontext.c src/secret/secret_driver.c src/security/security_apparmor.c src/security/security_dac.c diff --git a/src/Makefile.am b/src/Makefile.am index 639a41e..20ea2fb 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1159,10 +1159,13 @@ noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ rpc/virnetprotocol.h rpc/virnetprotocol.c \ - rpc/virnetsocket.h rpc/virnetsocket.c + rpc/virnetsocket.h rpc/virnetsocket.c \ + rpc/virnettlscontext.h 
rpc/virnettlscontext.c libvirt_net_rpc_la_CFLAGS = \ + $(GNUTLS_CFLAGS) \ $(AM_CFLAGS) libvirt_net_rpc_la_LDFLAGS = \ + $(GNUTLS_LIBS) \ $(AM_LDFLAGS) \ $(CYGWIN_EXTRA_LDFLAGS) \ $(MINGW_EXTRA_LDFLAGS) diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c new file mode 100644 index 0000000..db9ce31 --- /dev/null +++ b/src/rpc/virnettlscontext.c @@ -0,0 +1,894 @@ +/* + * virnettlscontext.c: TLS encryption/x509 handling + * + * Copyright (C) 2010-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include <config.h> + +#include <unistd.h> +#include <fnmatch.h> +#include <stdlib.h> + +#include <gnutls/gnutls.h> +#include <gnutls/x509.h> +#include "gnutls_1_0_compat.h" + +#include "virnettlscontext.h" + +#include "memory.h" +#include "virterror_internal.h" +#include "util.h" +#include "logging.h" +#include "configmake.h" + +#define DH_BITS 1024 + +#define LIBVIRT_PKI_DIR SYSCONFDIR "/pki" +#define LIBVIRT_CACERT LIBVIRT_PKI_DIR "/CA/cacert.pem" +#define LIBVIRT_CACRL LIBVIRT_PKI_DIR "/CA/cacrl.pem" +#define LIBVIRT_CLIENTKEY LIBVIRT_PKI_DIR "/libvirt/private/clientkey.pem" +#define LIBVIRT_CLIENTCERT LIBVIRT_PKI_DIR "/libvirt/clientcert.pem" +#define LIBVIRT_SERVERKEY LIBVIRT_PKI_DIR "/libvirt/private/serverkey.pem" +#define LIBVIRT_SERVERCERT LIBVIRT_PKI_DIR "/libvirt/servercert.pem" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) 
\ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetTLSContext { + int refs; + + gnutls_certificate_credentials_t x509cred; + gnutls_dh_params_t dhParams; + + bool isServer; + bool requireValidCert; + const char *const*x509dnWhitelist; +}; + +struct _virNetTLSSession { + int refs; + + bool handshakeComplete; + + char *hostname; + gnutls_session_t session; + virNetTLSSessionWriteFunc writeFunc; + virNetTLSSessionReadFunc readFunc; + void *opaque; +}; + + +static int +virNetTLSContextCheckCertFile(const char *type, const char *file, bool allowMissing) +{ + if (!virFileExists(file)) { + if (allowMissing) + return 1; + + virReportSystemError(errno, + _("Cannot read %s '%s'"), + type, file); + return -1; + } + return 0; +} + + +static void virNetTLSLog(int level, const char *str) { + VIR_DEBUG("%d %s", level, str); +} + +static int virNetTLSContextLoadCredentials(virNetTLSContextPtr ctxt, + bool isServer, + const char *cacert, + const char *cacrl, + const char *cert, + const char *key) +{ + int ret = -1; + int err; + + if (cacert && cacert[0] != '\0') { + if (virNetTLSContextCheckCertFile("CA certificate", cacert, false) < 0) + goto cleanup; + + VIR_DEBUG("loading CA cert from %s", cacert); + err = gnutls_certificate_set_x509_trust_file(ctxt->x509cred, + cacert, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 CA certificate: %s: %s"), + cacert, gnutls_strerror (err)); + goto cleanup; + } + } + + if (cacrl && cacrl[0] != '\0') { + int rv; + if ((rv = virNetTLSContextCheckCertFile("CA revocation list", cacrl, true)) < 0) + goto cleanup; + + if (rv == 0) { + VIR_DEBUG("loading CRL from %s", cacrl); + err = gnutls_certificate_set_x509_crl_file(ctxt->x509cred, + cacrl, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 certificate revocation list: %s: %s"), + cacrl, gnutls_strerror(err)); + goto cleanup; + } + } 
else { + VIR_DEBUG("Skipping non-existent CA CRL %s", cacrl); + } + } + + if (cert && cert[0] != '\0' && key && key[0] != '\0') { + int rv; + if ((rv = virNetTLSContextCheckCertFile("certificate", cert, !isServer)) < 0) + goto cleanup; + if (rv == 0 && + (rv = virNetTLSContextCheckCertFile("private key", key, !isServer)) < 0) + goto cleanup; + + if (rv == 0) { + VIR_DEBUG("loading cert and key from %s and %s", cert, key); + err = + gnutls_certificate_set_x509_key_file(ctxt->x509cred, + cert, key, + GNUTLS_X509_FMT_PEM); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to set x509 key and certificate: %s, %s: %s"), + key, cert, gnutls_strerror(err)); + goto cleanup; + } + } else { + VIR_DEBUG("Skipping non-existant cert %s key %s on client", cert, key); + } + } + + ret = 0; + +cleanup: + return ret; +} + + +static virNetTLSContextPtr virNetTLSContextNew(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + const char *const*x509dnWhitelist, + bool requireValidCert, + bool isServer) +{ + virNetTLSContextPtr ctxt; + char *gnutlsdebug; + int err; + + VIR_DEBUG("cacert=%s cacrl=%s cert=%s key=%s requireValid=%d isServer=%d", + cacert, NULLSTR(cacrl), cert, key, requireValidCert, isServer); + + if (VIR_ALLOC(ctxt) < 0) { + virReportOOMError(); + return NULL; + } + + ctxt->refs = 1; + + /* Initialise GnuTLS. 
*/ + gnutls_global_init(); + + if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) { + int val; + if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0) + val = 10; + gnutls_global_set_log_level(val); + gnutls_global_set_log_function(virNetTLSLog); + VIR_DEBUG("Enabled GNUTLS debug"); + } + + + err = gnutls_certificate_allocate_credentials(&ctxt->x509cred); + if (err) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to allocate x509 credentials: %s"), + gnutls_strerror(err)); + goto error; + } + + if (virNetTLSContextLoadCredentials(ctxt, isServer, cacert, cacrl, cert, key) < 0) + goto error; + + /* Generate Diffie Hellman parameters - for use with DHE + * kx algorithms. These should be discarded and regenerated + * once a day, once a week or once a month. Depending on the + * security requirements. + */ + if (isServer) { + err = gnutls_dh_params_init(&ctxt->dhParams); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to initialize diffie-hellman parameters: %s"), + gnutls_strerror(err)); + goto error; + } + err = gnutls_dh_params_generate2(ctxt->dhParams, DH_BITS); + if (err < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to generate diffie-hellman parameters: %s"), + gnutls_strerror(err)); + goto error; + } + + gnutls_certificate_set_dh_params(ctxt->x509cred, + ctxt->dhParams); + } + + ctxt->requireValidCert = requireValidCert; + ctxt->x509dnWhitelist = x509dnWhitelist; + ctxt->isServer = isServer; + + return ctxt; + +error: + if (isServer) + gnutls_dh_params_deinit(ctxt->dhParams); + gnutls_certificate_free_credentials(ctxt->x509cred); + VIR_FREE(ctxt); + return NULL; +} + + +static int virNetTLSContextLocateCredentials(const char *pkipath, + bool tryUserPkiPath, + bool isServer, + char **cacert, + char **cacrl, + char **cert, + char **key) +{ + char *userdir = NULL; + char *user_pki_path = NULL; + + *cacert = NULL; + *cacrl = NULL; + *key = NULL; + *cert = NULL; + + VIR_DEBUG("pkipath=%s isServer=%d tryUserPkiPath=%d", + pkipath, 
isServer, tryUserPkiPath); + + /* Explicit path, then use that no matter whether the + * files actually exist there + */ + if (pkipath) { + VIR_DEBUG("Told to use TLS credentials in %s", pkipath); + if ((virAsprintf(cacert, "%s/%s", pkipath, + "cacert.pem")) < 0) + goto out_of_memory; + if ((virAsprintf(cacrl, "%s/%s", pkipath, + "cacrl.pem")) < 0) + goto out_of_memory; + if ((virAsprintf(key, "%s/%s", pkipath, + isServer ? "serverkey.pem" : "clientkey.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(cert, "%s/%s", pkipath, + isServer ? "servercert.pem" : "clientcert.pem")) < 0) + goto out_of_memory; + } else if (tryUserPkiPath) { + /* Check to see if $HOME/.pki contains at least one of the + * files and if so, use that + */ + userdir = virGetUserDirectory(getuid()); + + if (!userdir) + goto out_of_memory; + + if (virAsprintf(&user_pki_path, "%s/.pki/libvirt", userdir) < 0) + goto out_of_memory; + + VIR_DEBUG("Trying to find TLS user credentials in %s", user_pki_path); + + if ((virAsprintf(cacert, "%s/%s", user_pki_path, + "cacert.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(cacrl, "%s/%s", user_pki_path, + "cacrl.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(key, "%s/%s", user_pki_path, + isServer ? "serverkey.pem" : "clientkey.pem")) < 0) + goto out_of_memory; + + if ((virAsprintf(cert, "%s/%s", user_pki_path, + isServer ? 
"servercert.pem" : "clientcert.pem")) < 0) + goto out_of_memory; + + /* + * If some of the files can't be found, fallback + * to the global location for them + */ + if (!virFileExists(*cacert)) + VIR_FREE(*cacert); + if (!virFileExists(*cacrl)) + VIR_FREE(*cacrl); + + /* Check these as a pair, since it they are + * mutually dependent + */ + if (!virFileExists(*key) || !virFileExists(*cert)) { + VIR_FREE(*key); + VIR_FREE(*cert); + } + } + + /* No explicit path, or user path didn't exist, so + * fallback to global defaults + */ + if (!*cacert) { + VIR_DEBUG("Using default TLS CA certificate path"); + if (!(*cacert = strdup(LIBVIRT_CACERT))) + goto out_of_memory; + } + + if (!*cacrl) { + VIR_DEBUG("Using default TLS CA revocation list path"); + if (!(*cacrl = strdup(LIBVIRT_CACRL))) + goto out_of_memory; + } + + if (!*key && !*cert) { + VIR_DEBUG("Using default TLS key/certificate path"); + if (!(*key = strdup(isServer ? LIBVIRT_SERVERKEY : LIBVIRT_CLIENTKEY))) + goto out_of_memory; + + if (!(*cert = strdup(isServer ? 
LIBVIRT_SERVERCERT : LIBVIRT_CLIENTCERT))) + goto out_of_memory; + } + + VIR_FREE(user_pki_path); + VIR_FREE(userdir); + + return 0; + +out_of_memory: + virReportOOMError(); + VIR_FREE(*cacert); + VIR_FREE(*cacrl); + VIR_FREE(*key); + VIR_FREE(*cert); + VIR_FREE(user_pki_path); + VIR_FREE(userdir); + return -1; +} + + +static virNetTLSContextPtr virNetTLSContextNewPath(const char *pkipath, + bool tryUserPkiPath, + const char *const*x509dnWhitelist, + bool requireValidCert, + bool isServer) +{ + char *cacert = NULL, *cacrl = NULL, *key = NULL, *cert = NULL; + virNetTLSContextPtr ctxt = NULL; + + if (virNetTLSContextLocateCredentials(pkipath, tryUserPkiPath, isServer, + &cacert, &cacrl, &key, &cert) < 0) + return NULL; + + ctxt = virNetTLSContextNew(cacert, cacrl, key, cert, + x509dnWhitelist, requireValidCert, isServer); + + VIR_FREE(cacert); + VIR_FREE(cacrl); + VIR_FREE(key); + VIR_FREE(cert); + + return ctxt; +} + +virNetTLSContextPtr virNetTLSContextNewServerPath(const char *pkipath, + bool tryUserPkiPath, + const char *const*x509dnWhitelist, + bool requireValidCert) +{ + return virNetTLSContextNewPath(pkipath, tryUserPkiPath, + x509dnWhitelist, requireValidCert, true); +} + +virNetTLSContextPtr virNetTLSContextNewClientPath(const char *pkipath, + bool tryUserPkiPath, + bool requireValidCert) +{ + return virNetTLSContextNewPath(pkipath, tryUserPkiPath, + NULL, requireValidCert, false); +} + + +virNetTLSContextPtr virNetTLSContextNewServer(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + const char *const*x509dnWhitelist, + bool requireValidCert) +{ + return virNetTLSContextNew(cacert, cacrl, key, cert, + x509dnWhitelist, requireValidCert, true); +} + + +virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + bool requireValidCert) +{ + return virNetTLSContextNew(cacert, cacrl, key, cert, + NULL, requireValidCert, false); +} + + +void 
virNetTLSContextRef(virNetTLSContextPtr ctxt) +{ + ctxt->refs++; +} + + +/* Check DN is on tls_allowed_dn_list. */ +static int +virNetTLSContextCheckDN(virNetTLSContextPtr ctxt, + const char *dname) +{ + const char *const*wildcards; + + /* If the list is not set, allow any DN. */ + wildcards = ctxt->x509dnWhitelist; + if (!wildcards) + return 1; + + while (*wildcards) { + int ret = fnmatch (*wildcards, dname, 0); + if (ret == 0) /* Succesful match */ + return 1; + if (ret != FNM_NOMATCH) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Malformed TLS whitelist regular expression '%s'"), + *wildcards); + return -1; + } + + wildcards++; + } + + /* Log the client's DN for debugging */ + VIR_DEBUG(_("Failed whitelist check for client DN '%s'"), dname); + + /* This is the most common error: make it informative. */ + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Client's Distinguished Name is not on the list " + "of allowed clients (tls_allowed_dn_list). Use " + "'certtool -i --infile clientcert.pem' to view the" + "Distinguished Name field in the client certificate," + "or run this daemon with --verbose option.")); + return 0; +} + +static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess) +{ + int ret; + unsigned int status; + const gnutls_datum_t *certs; + unsigned int nCerts, i; + time_t now; + char name[256]; + size_t namesize = sizeof name; + + memset(name, 0, namesize); + + if ((ret = gnutls_certificate_verify_peers2(sess->session, &status)) < 0){ + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Unable to verify TLS peer: %s"), + gnutls_strerror(ret)); + goto authdeny; + } + + if ((now = time(NULL)) == ((time_t)-1)) { + virReportSystemError(errno, "%s", + _("cannot get current time")); + goto authfail; + } + + if (status != 0) { + const char *reason = _("Invalid certificate"); + + if (status & GNUTLS_CERT_INVALID) + reason = _("The certificate is not trusted."); + + if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) + reason = _("The certificate 
hasn't got a known issuer."); + + if (status & GNUTLS_CERT_REVOKED) + reason = _("The certificate has been revoked."); + +#ifndef GNUTLS_1_0_COMPAT + if (status & GNUTLS_CERT_INSECURE_ALGORITHM) + reason = _("The certificate uses an insecure algorithm"); +#endif + + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Certificate failed validation: %s"), + reason); + goto authdeny; + } + + if (gnutls_certificate_type_get(sess->session) != GNUTLS_CRT_X509) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Only x509 certificates are supported")); + goto authdeny; + } + + if (!(certs = gnutls_certificate_get_peers(sess->session, &nCerts))) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The certificate has no peers")); + goto authdeny; + } + + for (i = 0; i < nCerts; i++) { + gnutls_x509_crt_t cert; + + if (gnutls_x509_crt_init(&cert) < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Unable to initialize certificate")); + goto authfail; + } + + if (gnutls_x509_crt_import(cert, &certs[i], GNUTLS_X509_FMT_DER) < 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Unable to load certificate")); + gnutls_x509_crt_deinit(cert); + goto authfail; + } + + if (gnutls_x509_crt_get_expiration_time(cert) < now) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The client certificate has expired")); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (gnutls_x509_crt_get_activation_time(cert) > now) { + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("The client certificate is not yet active")); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (i == 0) { + ret = gnutls_x509_crt_get_dn(cert, name, &namesize); + if (ret != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to get certificate distinguished name: %s"), + gnutls_strerror(ret)); + gnutls_x509_crt_deinit(cert); + goto authfail; + } + + if (virNetTLSContextCheckDN(ctxt, name) <= 0) { + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + + if (sess->hostname && + !gnutls_x509_crt_check_hostname(cert, 
sess->hostname)) { + virNetError(VIR_ERR_RPC, + _("Certificate's owner does not match the hostname (%s)"), + sess->hostname); + gnutls_x509_crt_deinit(cert); + goto authdeny; + } + } + } + +#if 0 + PROBE(CLIENT_TLS_ALLOW, "fd=%d, name=%s", + virNetServerClientGetFD(client), name); +#endif + return 0; + +authdeny: +#if 0 + PROBE(CLIENT_TLS_DENY, "fd=%d, name=%s", + virNetServerClientGetFD(client), name); +#endif + return -1; + +authfail: +#if 0 + PROBE(CLIENT_TLS_FAIL, "fd=%d", + virNetServerClientGetFD(client)); +#endif + return -1; +} + +int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess) { + if (virNetTLSContextValidCertificate(ctxt, sess) < 0) { + if (ctxt->requireValidCert) { + virNetError(VIR_ERR_AUTH_FAILED, "%s", + _("Failed to verify peer's certificate")); + return -1; + } + VIR_INFO("Ignoring bad certificate at user request"); + } + return 0; +} + +void virNetTLSContextFree(virNetTLSContextPtr ctxt) +{ + if (!ctxt) + return; + + ctxt->refs--; + if (ctxt->refs > 0) + return; + + gnutls_dh_params_deinit(ctxt->dhParams); + gnutls_certificate_free_credentials(ctxt->x509cred); + VIR_FREE(ctxt); +} + + + +static ssize_t +virNetTLSSessionPush(void *opaque, const void *buf, size_t len) +{ + virNetTLSSessionPtr sess = opaque; + if (!sess->writeFunc) { + VIR_WARN("TLS session push with missing read function"); + errno = EIO; + return -1; + }; + + return sess->writeFunc(buf, len, sess->opaque); +} + + +static ssize_t +virNetTLSSessionPull(void *opaque, void *buf, size_t len) +{ + virNetTLSSessionPtr sess = opaque; + if (!sess->readFunc) { + VIR_WARN("TLS session pull with missing read function"); + errno = EIO; + return -1; + }; + + return sess->readFunc(buf, len, sess->opaque); +} + + +virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, + const char *hostname) +{ + virNetTLSSessionPtr sess; + int err; + static const int cert_type_priority[] = { GNUTLS_CRT_X509, 0 }; + + VIR_DEBUG("ctxt=%p hostname=%s isServer=%d", 
ctxt, NULLSTR(hostname), ctxt->isServer); + + if (VIR_ALLOC(sess) < 0) { + virReportOOMError(); + return NULL; + } + + sess->refs = 1; + if (hostname && + !(sess->hostname = strdup(hostname))) { + virReportOOMError(); + goto error; + } + + if ((err = gnutls_init(&sess->session, + ctxt->isServer ? GNUTLS_SERVER : GNUTLS_CLIENT)) != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to initialize TLS session: %s"), + gnutls_strerror(err)); + goto error; + } + + /* avoid calling all the priority functions, since the defaults + * are adequate. + */ + if ((err = gnutls_set_default_priority(sess->session)) != 0 || + (err = gnutls_certificate_type_set_priority(sess->session, + cert_type_priority))) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed to set TLS session priority %s"), + gnutls_strerror(err)); + goto error; + } + + if ((err = gnutls_credentials_set(sess->session, + GNUTLS_CRD_CERTIFICATE, + ctxt->x509cred)) != 0) { + virNetError(VIR_ERR_SYSTEM_ERROR, + _("Failed set TLS x509 credentials: %s"), + gnutls_strerror(err)); + goto error; + } + + /* request client certificate if any. 
+ */ + if (ctxt->isServer) { + gnutls_certificate_server_set_request(sess->session, GNUTLS_CERT_REQUEST); + + gnutls_dh_set_prime_bits(sess->session, DH_BITS); + } + + gnutls_transport_set_ptr(sess->session, sess); + gnutls_transport_set_push_function(sess->session, + virNetTLSSessionPush); + gnutls_transport_set_pull_function(sess->session, + virNetTLSSessionPull); + + return sess; + +error: + virNetTLSSessionFree(sess); + return NULL; +} + + +void virNetTLSSessionRef(virNetTLSSessionPtr sess) +{ + sess->refs++; +} + +void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, + virNetTLSSessionWriteFunc writeFunc, + virNetTLSSessionReadFunc readFunc, + void *opaque) +{ + sess->writeFunc = writeFunc; + sess->readFunc = readFunc; + sess->opaque = opaque; +} + + +ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, + const char *buf, size_t len) +{ + ssize_t ret; + ret = gnutls_record_send(sess->session, buf, len); + + if (ret >= 0) + return ret; + + switch (ret) { + case GNUTLS_E_AGAIN: + errno = EAGAIN; + break; + case GNUTLS_E_INTERRUPTED: + errno = EINTR; + break; + default: + errno = EIO; + break; + } + + return -1; +} + +ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess, + char *buf, size_t len) +{ + ssize_t ret; + + ret = gnutls_record_recv(sess->session, buf, len); + + if (ret >= 0) + return ret; + + switch (ret) { + case GNUTLS_E_AGAIN: + errno = EAGAIN; + break; + case GNUTLS_E_INTERRUPTED: + errno = EINTR; + break; + default: + errno = EIO; + break; + } + + return -1; +} + +int virNetTLSSessionHandshake(virNetTLSSessionPtr sess) +{ + VIR_DEBUG("sess=%p", sess); + int ret = gnutls_handshake(sess->session); + VIR_DEBUG("Ret=%d", ret); + if (ret == 0) { + sess->handshakeComplete = true; + VIR_DEBUG("Handshake is complete"); + return 0; + } + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + return 1; + +#if 0 + PROBE(CLIENT_TLS_FAIL, "fd=%d", + virNetServerClientGetFD(client)); +#endif + + virNetError(VIR_ERR_AUTH_FAILED, + _("TLS 
handshake failed %s"), + gnutls_strerror(ret)); + return -1; +} + +virNetTLSSessionHandshakeStatus +virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess) +{ + if (sess->handshakeComplete) + return VIR_NET_TLS_HANDSHAKE_COMPLETE; + else if (gnutls_record_get_direction(sess->session) == 0) + return VIR_NET_TLS_HANDSHAKE_RECVING; + else + return VIR_NET_TLS_HANDSHAKE_SENDING; +} + +int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess) +{ + gnutls_cipher_algorithm_t cipher; + int ssf; + + cipher = gnutls_cipher_get(sess->session); + if (!(ssf = gnutls_cipher_get_key_size(cipher))) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("invalid cipher size for TLS session")); + return -1; + } + + return ssf; +} + + +void virNetTLSSessionFree(virNetTLSSessionPtr sess) +{ + if (!sess) + return; + + sess->refs--; + if (sess->refs > 0) + return; + + VIR_FREE(sess->hostname); + gnutls_deinit(sess->session); + VIR_FREE(sess); +} diff --git a/src/rpc/virnettlscontext.h b/src/rpc/virnettlscontext.h new file mode 100644 index 0000000..f23667f --- /dev/null +++ b/src/rpc/virnettlscontext.h @@ -0,0 +1,99 @@ +/* + * virnettlscontext.h: TLS encryption/x509 handling + * + * Copyright (C) 2010-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef __VIR_NET_TLS_CONTEXT_H__ +# define __VIR_NET_TLS_CONTEXT_H__ + +# include "internal.h" + +typedef struct _virNetTLSContext virNetTLSContext; +typedef virNetTLSContext *virNetTLSContextPtr; + +typedef struct _virNetTLSSession virNetTLSSession; +typedef virNetTLSSession *virNetTLSSessionPtr; + + +virNetTLSContextPtr virNetTLSContextNewServerPath(const char *pkipath, + bool tryUserPkiPath, + const char *const*x509dnWhitelist, + bool requireValidCert); + +virNetTLSContextPtr virNetTLSContextNewClientPath(const char *pkipath, + bool tryUserPkiPath, + bool requireValidCert); + +virNetTLSContextPtr virNetTLSContextNewServer(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + const char *const*x509dnWhitelist, + bool requireValidCert); + +virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert, + const char *cacrl, + const char *cert, + const char *key, + bool requireValidCert); + +void virNetTLSContextRef(virNetTLSContextPtr ctxt); + +int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt, + virNetTLSSessionPtr sess); + +void virNetTLSContextFree(virNetTLSContextPtr ctxt); + + +typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len, + void *opaque); +typedef ssize_t (*virNetTLSSessionReadFunc)(char *buf, size_t len, + void *opaque); + +virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt, + const char *hostname); + +void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess, + virNetTLSSessionWriteFunc writeFunc, + virNetTLSSessionReadFunc readFunc, + void *opaque); + +void virNetTLSSessionRef(virNetTLSSessionPtr sess); + +ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess, + const char *buf, size_t len); +ssize_t virNetTLSSessionRead(virNetTLSSessionPtr 
sess, + char *buf, size_t len); + +int virNetTLSSessionHandshake(virNetTLSSessionPtr sess); + +typedef enum { + VIR_NET_TLS_HANDSHAKE_COMPLETE, + VIR_NET_TLS_HANDSHAKE_SENDING, + VIR_NET_TLS_HANDSHAKE_RECVING, +} virNetTLSSessionHandshakeStatus; + +virNetTLSSessionHandshakeStatus +virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess); + +int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess); + +void virNetTLSSessionFree(virNetTLSSessionPtr sess); + + +#endif -- 1.7.4.4

This provides two modules for handling SASL * virNetSASLContext provides the process-wide state, currently just a whitelist of usernames on the server and a one time library init call * virNetSASLSession provides the per-connection state, i.e. the SASL session itself. This also includes APIs for providing data encryption/decryption once the session is established * src/Makefile.am: Add to libvirt-net-rpc.la * src/rpc/virnetsaslcontext.c, src/rpc/virnetsaslcontext.h: Generic SASL handling code --- cfg.mk | 2 + po/POTFILES.in | 1 + src/Makefile.am | 9 + src/rpc/virnetsaslcontext.c | 598 +++++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetsaslcontext.h | 120 +++++++++ 5 files changed, 730 insertions(+), 0 deletions(-) create mode 100644 src/rpc/virnetsaslcontext.c create mode 100644 src/rpc/virnetsaslcontext.h diff --git a/cfg.mk b/cfg.mk index 2011968..d4a7387 100644 --- a/cfg.mk +++ b/cfg.mk @@ -127,6 +127,8 @@ useless_free_options = \ --name=virLastErrFreeData \ --name=virNetMessageFree \ --name=virNetSocketFree \ + --name=virNetSASLContextFree \ + --name=virNetSASLSessionFree \ --name=virNetTLSSessionFree \ --name=virNWFilterDefFree \ --name=virNWFilterEntryFree \ diff --git a/po/POTFILES.in b/po/POTFILES.in index 55cb816..59316f1 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -68,6 +68,7 @@ src/qemu/qemu_process.c src/remote/remote_client_bodies.h src/remote/remote_driver.c src/rpc/virnetmessage.c +src/rpc/virnetsaslcontext.c src/rpc/virnetsocket.c src/rpc/virnettlscontext.c src/secret/secret_driver.c diff --git a/src/Makefile.am b/src/Makefile.am index 20ea2fb..4907806 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1161,11 +1161,20 @@ libvirt_net_rpc_la_SOURCES = \ rpc/virnetprotocol.h rpc/virnetprotocol.c \ rpc/virnetsocket.h rpc/virnetsocket.c \ rpc/virnettlscontext.h rpc/virnettlscontext.c +if HAVE_SASL +libvirt_net_rpc_la_SOURCES += \ + rpc/virnetsaslcontext.h rpc/virnetsaslcontext.c +else +EXTRA_DIST += \ + rpc/virnetsaslcontext.h 
rpc/virnetsaslcontext.c +endif libvirt_net_rpc_la_CFLAGS = \ $(GNUTLS_CFLAGS) \ + $(SASL_CFLAGS) \ $(AM_CFLAGS) libvirt_net_rpc_la_LDFLAGS = \ $(GNUTLS_LIBS) \ + $(SASL_LIBS) \ $(AM_LDFLAGS) \ $(CYGWIN_EXTRA_LDFLAGS) \ $(MINGW_EXTRA_LDFLAGS) diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c new file mode 100644 index 0000000..6b2a883 --- /dev/null +++ b/src/rpc/virnetsaslcontext.c @@ -0,0 +1,598 @@ +/* + * virnetsaslcontext.c: SASL encryption/auth handling + * + * Copyright (C) 2010-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include <config.h> + +#include <fnmatch.h> + +#include "virnetsaslcontext.h" +#include "virnetmessage.h" + +#include "virterror_internal.h" +#include "memory.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) 
\ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + + +struct _virNetSASLContext { + const char *const*usernameWhitelist; + int refs; +}; + +struct _virNetSASLSession { + sasl_conn_t *conn; + int refs; + size_t maxbufsize; +}; + + +virNetSASLContextPtr virNetSASLContextNewClient(void) +{ + virNetSASLContextPtr ctxt; + int err; + + err = sasl_client_init(NULL); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("failed to initialize SASL library: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return NULL; + } + + if (VIR_ALLOC(ctxt) < 0) { + virReportOOMError(); + return NULL; + } + + ctxt->refs = 1; + + return ctxt; +} + +virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitelist) +{ + virNetSASLContextPtr ctxt; + int err; + + err = sasl_server_init(NULL, "libvirt"); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("failed to initialize SASL library: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return NULL; + } + + if (VIR_ALLOC(ctxt) < 0) { + virReportOOMError(); + return NULL; + } + + ctxt->usernameWhitelist = usernameWhitelist; + ctxt->refs = 1; + + return ctxt; +} + +int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt, + const char *identity) +{ + const char *const*wildcards; + + /* If the list is not set, allow any DN. */ + wildcards = ctxt->usernameWhitelist; + if (!wildcards) + return 1; /* No ACL, allow all */ + + while (*wildcards) { + int ret = fnmatch (*wildcards, identity, 0); + if (ret == 0) /* Succesful match */ + return 1; + if (ret != FNM_NOMATCH) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Malformed TLS whitelist regular expression '%s'"), + *wildcards); + return -1; + } + + wildcards++; + } + + /* Denied */ + VIR_ERROR(_("SASL client %s not allowed in whitelist"), identity); + + /* This is the most common error: make it informative. 
*/ + virNetError(VIR_ERR_SYSTEM_ERROR, "%s", + _("Client's username is not on the list of allowed clients")); + return 0; +} + + +void virNetSASLContextRef(virNetSASLContextPtr ctxt) +{ + ctxt->refs++; +} + +void virNetSASLContextFree(virNetSASLContextPtr ctxt) +{ + if (!ctxt) + return; + + ctxt->refs--; + if (ctxt->refs > 0) + return; + + VIR_FREE(ctxt); +} + +virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED, + const char *service, + const char *hostname, + const char *localAddr, + const char *remoteAddr, + const sasl_callback_t *cbs) +{ + virNetSASLSessionPtr sasl = NULL; + int err; + + if (VIR_ALLOC(sasl) < 0) { + virReportOOMError(); + goto cleanup; + } + + sasl->refs = 1; + /* Arbitrary size for amount of data we can encode in a single block */ + sasl->maxbufsize = 1 << 16; + + err = sasl_client_new(service, + hostname, + localAddr, + remoteAddr, + cbs, + SASL_SUCCESS_DATA, + &sasl->conn); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to create SASL client context: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + goto cleanup; + } + + return sasl; + +cleanup: + virNetSASLSessionFree(sasl); + return NULL; +} + +virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED, + const char *service, + const char *localAddr, + const char *remoteAddr) +{ + virNetSASLSessionPtr sasl = NULL; + int err; + + if (VIR_ALLOC(sasl) < 0) { + virReportOOMError(); + goto cleanup; + } + + sasl->refs = 1; + /* Arbitrary size for amount of data we can encode in a single block */ + sasl->maxbufsize = 1 << 16; + + err = sasl_server_new(service, + NULL, + NULL, + localAddr, + remoteAddr, + NULL, + SASL_SUCCESS_DATA, + &sasl->conn); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to create SASL client context: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + goto cleanup; + } + + return sasl; + +cleanup: + virNetSASLSessionFree(sasl); + return NULL; +} + 
+void virNetSASLSessionRef(virNetSASLSessionPtr sasl) +{ + sasl->refs++; +} + +int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl, + int ssf) +{ + int err; + + err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot set external SSF %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + return 0; +} + +const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl) +{ + const void *val; + int err; + + err = sasl_getprop(sasl->conn, SASL_USERNAME, &val); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("cannot query SASL username on connection %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return NULL; + } + if (val == NULL) { + virNetError(VIR_ERR_AUTH_FAILED, + _("no client username was found")); + return NULL; + } + VIR_DEBUG("SASL client username %s", (const char *)val); + + return (const char*)val; +} + + +int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl) +{ + int err; + int ssf; + const void *val; + err = sasl_getprop(sasl->conn, SASL_SSF, &val); + if (err != SASL_OK) { + virNetError(VIR_ERR_AUTH_FAILED, + _("cannot query SASL ssf on connection %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + ssf = *(const int *)val; + return ssf; +} + +int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, + int minSSF, + int maxSSF, + bool allowAnonymous) +{ + sasl_security_properties_t secprops; + int err; + + VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu", + minSSF, maxSSF, allowAnonymous, sasl->maxbufsize); + + memset(&secprops, 0, sizeof secprops); + + secprops.min_ssf = minSSF; + secprops.max_ssf = maxSSF; + secprops.maxbufsize = sasl->maxbufsize; + secprops.security_flags = allowAnonymous ? 
0 : + SASL_SEC_NOANONYMOUS | SASL_SEC_NOPLAINTEXT; + + err = sasl_setprop(sasl->conn, SASL_SEC_PROPS, &secprops); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot set security props %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + + return 0; +} + + +static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl) +{ + unsigned *maxbufsize; + int err; + + err = sasl_getprop(sasl->conn, SASL_MAXOUTBUF, (const void **)&maxbufsize); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot get security props %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + + VIR_DEBUG("Negotiated bufsize is %u vs requested size %zu", + *maxbufsize, sasl->maxbufsize); + sasl->maxbufsize = *maxbufsize; + return 0; +} + +char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl) +{ + const char *mechlist; + char *ret; + int err; + + err = sasl_listmech(sasl->conn, + NULL, /* Don't need to set user */ + "", /* Prefix */ + ",", /* Separator */ + "", /* Suffix */ + &mechlist, + NULL, + NULL); + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("cannot list SASL mechanisms %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return NULL; + } + if (!(ret = strdup(mechlist))) { + virReportOOMError(); + return NULL; + } + return ret; +} + + +int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl, + const char *mechlist, + sasl_interact_t **prompt_need, + const char **clientout, + size_t *clientoutlen, + const char **mech) +{ + unsigned outlen = 0; + + VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p", + sasl, mechlist, prompt_need, clientout, clientoutlen, mech); + + int err = sasl_client_start(sasl->conn, + mechlist, + prompt_need, + clientout, + &outlen, + mech); + + *clientoutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return 
VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to start SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + + +int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl, + const char *serverin, + size_t serverinlen, + sasl_interact_t **prompt_need, + const char **clientout, + size_t *clientoutlen) +{ + unsigned inlen = serverinlen; + unsigned outlen = 0; + + VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p", + sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen); + + int err = sasl_client_step(sasl->conn, + serverin, + inlen, + prompt_need, + clientout, + &outlen); + *clientoutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to step SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + +int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, + const char *mechname, + const char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen) +{ + unsigned inlen = clientinlen; + unsigned outlen = 0; + int err = sasl_server_start(sasl->conn, + mechname, + clientin, + inlen, + serverout, + &outlen); + + *serveroutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to start SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + + +int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl, + const 
char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen) +{ + unsigned inlen = clientinlen; + unsigned outlen = 0; + + int err = sasl_server_step(sasl->conn, + clientin, + inlen, + serverout, + &outlen); + + *serveroutlen = outlen; + + switch (err) { + case SASL_OK: + if (virNetSASLSessionUpdateBufSize(sasl) < 0) + return -1; + return VIR_NET_SASL_COMPLETE; + case SASL_CONTINUE: + return VIR_NET_SASL_CONTINUE; + case SASL_INTERACT: + return VIR_NET_SASL_INTERACT; + + default: + virNetError(VIR_ERR_AUTH_FAILED, + _("Failed to start SASL negotiation: %d (%s)"), + err, sasl_errdetail(sasl->conn)); + return -1; + } +} + +size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl) +{ + return sasl->maxbufsize; +} + +ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen) +{ + unsigned inlen = inputLen; + unsigned outlen = 0; + int err; + + if (inputLen > sasl->maxbufsize) { + virReportSystemError(EINVAL, + _("SASL data length %zu too long, max %zu"), + inputLen, sasl->maxbufsize); + return -1; + } + + err = sasl_encode(sasl->conn, + input, + inlen, + output, + &outlen); + *outputlen = outlen; + + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("failed to encode SASL data: %d (%s)"), + err, sasl_errstring(err, NULL, NULL)); + return -1; + } + return 0; +} + +ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen) +{ + unsigned inlen = inputLen; + unsigned outlen = 0; + int err; + + if (inputLen > sasl->maxbufsize) { + virReportSystemError(EINVAL, + _("SASL data length %zu too long, max %zu"), + inputLen, sasl->maxbufsize); + return -1; + } + + err = sasl_decode(sasl->conn, + input, + inlen, + output, + &outlen); + *outputlen = outlen; + if (err != SASL_OK) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("failed to decode SASL data: %d (%s)"), + err, 
sasl_errstring(err, NULL, NULL)); + return -1; + } + return 0; +} + +void virNetSASLSessionFree(virNetSASLSessionPtr sasl) +{ + if (!sasl) + return; + + sasl->refs--; + if (sasl->refs > 0) + return; + + if (sasl->conn) + sasl_dispose(&sasl->conn); + + VIR_FREE(sasl); +} diff --git a/src/rpc/virnetsaslcontext.h b/src/rpc/virnetsaslcontext.h new file mode 100644 index 0000000..61a739d --- /dev/null +++ b/src/rpc/virnetsaslcontext.h @@ -0,0 +1,120 @@ +/* + * virnetsaslcontext.h: SASL encryption/auth handling + * + * Copyright (C) 2010-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef __VIR_NET_CLIENT_SASL_CONTEXT_H__ +# define __VIR_NET_CLIENT_SASL_CONTEXT_H__ + +# include <sasl/sasl.h> + +# include <stdbool.h> +# include <sys/types.h> + +typedef struct _virNetSASLContext virNetSASLContext; +typedef virNetSASLContext *virNetSASLContextPtr; + +typedef struct _virNetSASLSession virNetSASLSession; +typedef virNetSASLSession *virNetSASLSessionPtr; + +enum { + VIR_NET_SASL_COMPLETE, + VIR_NET_SASL_CONTINUE, + VIR_NET_SASL_INTERACT, +}; + +virNetSASLContextPtr virNetSASLContextNewClient(void); +virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitelist); + +int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt, + const char *identity); + +void virNetSASLContextRef(virNetSASLContextPtr sasl); +void virNetSASLContextFree(virNetSASLContextPtr sasl); + +virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt, + const char *service, + const char *hostname, + const char *localAddr, + const char *remoteAddr, + const sasl_callback_t *cbs); +virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt, + const char *service, + const char *localAddr, + const char *remoteAddr); + +char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl); + +void virNetSASLSessionRef(virNetSASLSessionPtr sasl); + +int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl, + int ssf); + +int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl); + +const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl); + +int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl, + int minSSF, + int maxSSF, + bool allowAnonymous); + +int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl, + const char *mechlist, + sasl_interact_t **prompt_need, + const char **clientout, + 
size_t *clientoutlen, + const char **mech); + +int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl, + const char *serverin, + size_t serverinlen, + sasl_interact_t **prompt_need, + const char **clientout, + size_t *clientoutlen); + +int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl, + const char *mechname, + const char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen); + +int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl, + const char *clientin, + size_t clientinlen, + const char **serverout, + size_t *serveroutlen); + +size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl); + +ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen); + +ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl, + const char *input, + size_t inputLen, + const char **output, + size_t *outputlen); + +void virNetSASLSessionFree(virNetSASLSessionPtr sasl); + +#endif /* __VIR_NET_CLIENT_SASL_CONTEXT_H__ */ -- 1.7.4.4

This extends the basic virNetSocket APIs to allow them to have a handle to the TLS/SASL session objects, once established. This ensures that any data reads/writes are automagically passed through the TLS/SASL encryption layers if required. * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up SASL/TLS encryption --- src/rpc/virnetsocket.c | 274 +++++++++++++++++++++++++++++++++++++++++++++++- src/rpc/virnetsocket.h | 11 ++ 2 files changed, 282 insertions(+), 3 deletions(-) diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index f7d1095..60ee4e5 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -27,6 +27,9 @@ #include <sys/socket.h> #include <unistd.h> #include <sys/wait.h> +#ifdef HAVE_NETINET_TCP_H +# include <netinet/tcp.h> +#endif #ifdef HAVE_NETINET_TCP_H # include <netinet/tcp.h> @@ -59,6 +62,19 @@ struct _virNetSocket { virSocketAddr remoteAddr; char *localAddrStr; char *remoteAddrStr; + + virNetTLSSessionPtr tlsSession; +#if HAVE_SASL + virNetSASLSessionPtr saslSession; + + const char *saslDecoded; + size_t saslDecodedLength; + size_t saslDecodedOffset; + + const char *saslEncoded; + size_t saslEncodedLength; + size_t saslEncodedOffset; +#endif }; @@ -417,7 +433,7 @@ error: } -#if HAVE_SYS_UN_H +#ifdef HAVE_SYS_UN_H int virNetSocketNewConnectUNIX(const char *path, bool spawnDaemon, const char *binary, @@ -633,6 +649,14 @@ void virNetSocketFree(virNetSocketPtr sock) unlink(sock->localAddr.data.un.sun_path); #endif + /* Make sure it can't send any more I/O during shutdown */ + if (sock->tlsSession) + virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL); + virNetTLSSessionFree(sock->tlsSession); +#if HAVE_SASL + virNetSASLSessionFree(sock->saslSession); +#endif + VIR_FORCE_CLOSE(sock->fd); VIR_FORCE_CLOSE(sock->errfd); @@ -718,14 +742,258 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) return sock->remoteAddrStr; } -ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) + +static 
ssize_t virNetSocketTLSSessionWrite(const char *buf, + size_t len, + void *opaque) { + virNetSocketPtr sock = opaque; + return write(sock->fd, buf, len); +} + + +static ssize_t virNetSocketTLSSessionRead(char *buf, + size_t len, + void *opaque) +{ + virNetSocketPtr sock = opaque; return read(sock->fd, buf, len); } + +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess) +{ + virNetTLSSessionFree(sock->tlsSession); + sock->tlsSession = sess; + virNetTLSSessionSetIOCallbacks(sess, + virNetSocketTLSSessionWrite, + virNetSocketTLSSessionRead, + sock); + virNetTLSSessionRef(sess); +} + + +#if HAVE_SASL +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess) +{ + virNetSASLSessionFree(sock->saslSession); + sock->saslSession = sess; + virNetSASLSessionRef(sess); +} +#endif + + +bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED) +{ +#if HAVE_SASL + if (sock->saslDecoded) + return true; +#endif + return false; +} + + +static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len) +{ + char *errout = NULL; + ssize_t ret; +reread: + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionRead(sock->tlsSession, buf, len); + } else { + ret = read(sock->fd, buf, len); + } + + if ((ret < 0) && (errno == EINTR)) + goto reread; + if ((ret < 0) && (errno == EAGAIN)) + return 0; + + if (ret <= 0 && + sock->errfd != -1 && + virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 && + errout != NULL) { + size_t elen = strlen(errout); + if (elen && errout[elen-1] == '\n') + errout[elen-1] = '\0'; + } + + if (ret < 0) { + if (errout) + virReportSystemError(errno, + _("Cannot recv data: %s"), errout); + else + virReportSystemError(errno, "%s", + _("Cannot recv data")); + ret = -1; + } else if (ret == 0) { + if (errout) + virReportSystemError(EIO, + _("End of file while reading data: %s"), errout); + else + 
virReportSystemError(EIO, "%s", + _("End of file while reading data")); + ret = -1; + } + + VIR_FREE(errout); + return ret; +} + +static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len) +{ + ssize_t ret; +rewrite: + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionWrite(sock->tlsSession, buf, len); + } else { + ret = write(sock->fd, buf, len); + } + + if (ret < 0) { + if (errno == EINTR) + goto rewrite; + if (errno == EAGAIN) + return 0; + + virReportSystemError(errno, "%s", + _("Cannot write data")); + return -1; + } + if (ret == 0) { + virReportSystemError(EIO, "%s", + _("End of file while writing data")); + return -1; + } + + return ret; +} + + +#if HAVE_SASL +static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len) +{ + ssize_t got; + + /* Need to read some more data off the wire */ + if (sock->saslDecoded == NULL) { + ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession); + char *encoded; + if (VIR_ALLOC_N(encoded, encodedLen) < 0) { + virReportOOMError(); + return -1; + } + encodedLen = virNetSocketReadWire(sock, encoded, encodedLen); + + if (encodedLen <= 0) { + VIR_FREE(encoded); + return encodedLen; + } + + if (virNetSASLSessionDecode(sock->saslSession, + encoded, encodedLen, + &sock->saslDecoded, &sock->saslDecodedLength) < 0) { + VIR_FREE(encoded); + return -1; + } + VIR_FREE(encoded); + + sock->saslDecodedOffset = 0; + } + + /* Some buffered decoded data to return now */ + got = sock->saslDecodedLength - sock->saslDecodedOffset; + + if (len > got) + len = got; + + memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len); + sock->saslDecodedOffset += len; + + if (sock->saslDecodedOffset == sock->saslDecodedLength) { + sock->saslDecoded = NULL; + sock->saslDecodedOffset = sock->saslDecodedLength = 0; + } + + return len; +} + + +static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, 
const char *buf, size_t len) +{ + int ret; + size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession); + + /* SASL doesn't neccessarily let us send the whole + buffer at once */ + if (tosend > len) + tosend = len; + + /* Not got any pending encoded data, so we need to encode raw stuff */ + if (sock->saslEncoded == NULL) { + if (virNetSASLSessionEncode(sock->saslSession, + buf, tosend, + &sock->saslEncoded, + &sock->saslEncodedLength) < 0) + return -1; + + sock->saslEncodedOffset = 0; + } + + /* Send some of the encoded stuff out on the wire */ + ret = virNetSocketWriteWire(sock, + sock->saslEncoded + sock->saslEncodedOffset, + sock->saslEncodedLength - sock->saslEncodedOffset); + + if (ret <= 0) + return ret; /* -1 error, 0 == egain */ + + /* Note how much we sent */ + sock->saslEncodedOffset += ret; + + /* Sent all encoded, so update raw buffer to indicate completion */ + if (sock->saslEncodedOffset == sock->saslEncodedLength) { + sock->saslEncoded = NULL; + sock->saslEncodedOffset = sock->saslEncodedLength = 0; + + /* Mark as complete, so caller detects completion */ + return tosend; + } else { + /* Still have stuff pending in saslEncoded buffer. + * Pretend to caller that we didn't send any yet. + * The caller will then retry with same buffer + * shortly, which lets us finish saslEncoded. 
+ */ + return 0; + } +} +#endif + + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) +{ +#if HAVE_SASL + if (sock->saslSession) + return virNetSocketReadSASL(sock, buf, len); + else +#endif + return virNetSocketReadWire(sock, buf, len); +} + ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) { - return write(sock->fd, buf, len); +#if HAVE_SASL + if (sock->saslSession) + return virNetSocketWriteSASL(sock, buf, len); + else +#endif + return virNetSocketWriteWire(sock, buf, len); } diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index 218fe8f..59ff288 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ -26,6 +26,10 @@ # include "network.h" # include "command.h" +# include "virnettlscontext.h" +# ifdef HAVE_SASL +# include "virnetsaslcontext.h" +# endif typedef struct _virNetSocket virNetSocket; typedef virNetSocket *virNetSocketPtr; @@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock, ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess); +# ifdef HAVE_SASL +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess); +# endif +bool virNetSocketHasCachedData(virNetSocketPtr sock); void virNetSocketFree(virNetSocketPtr sock); const char *virNetSocketLocalAddrString(virNetSocketPtr sock); -- 1.7.4.4

To facilitate creation of new daemons providing XDR RPC services, pull a lot of the libvirtd daemon code into a set of reusable objects. * virNetServer: A server contains one or more services which accept incoming clients. It maintains the list of active clients. It has a list of RPC programs which can be used by clients. When clients produce a complete RPC message, the server passes this on to the corresponding program for handling, and queues any response back with the client. * virNetServerClient: Encapsulates a single client connection. All I/O for the client is handled, reading & writing RPC messages. * virNetServerProgram: Handles processing and dispatch of RPC method calls for a single RPC (program,version). Multiple programs can be registered with the server. * virNetServerService: Encapsulates socket(s) listening for new connections. Each service listens on a single host/port, but may have multiple sockets if on a dual IPv4/6 host. Each new daemon now merely has to define the list of RPC procedures & their handlers. It does not need to deal with any network related functionality at all. 
--- cfg.mk | 4 + po/POTFILES.in | 3 + src/Makefile.am | 17 +- src/rpc/virnetserver.c | 714 +++++++++++++++++++++++++++++++ src/rpc/virnetserver.h | 80 ++++ src/rpc/virnetserverclient.c | 937 +++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetserverclient.h | 106 +++++ src/rpc/virnetserverprogram.c | 455 ++++++++++++++++++++ src/rpc/virnetserverprogram.h | 107 +++++ src/rpc/virnetserverservice.c | 247 +++++++++++ src/rpc/virnetserverservice.h | 65 +++ 11 files changed, 2734 insertions(+), 1 deletions(-) create mode 100644 src/rpc/virnetserver.c create mode 100644 src/rpc/virnetserver.h create mode 100644 src/rpc/virnetserverclient.c create mode 100644 src/rpc/virnetserverclient.h create mode 100644 src/rpc/virnetserverprogram.c create mode 100644 src/rpc/virnetserverprogram.h create mode 100644 src/rpc/virnetserverservice.c create mode 100644 src/rpc/virnetserverservice.h diff --git a/cfg.mk b/cfg.mk index d4a7387..1cd2a0f 100644 --- a/cfg.mk +++ b/cfg.mk @@ -126,6 +126,10 @@ useless_free_options = \ --name=virJSONValueFree \ --name=virLastErrFreeData \ --name=virNetMessageFree \ + --name=virNetServerFree \ + --name=virNetServerClientFree \ + --name=virNetServerProgramFree \ + --name=virNetServerServiceFree \ --name=virNetSocketFree \ --name=virNetSASLContextFree \ --name=virNetSASLSessionFree \ diff --git a/po/POTFILES.in b/po/POTFILES.in index 59316f1..8a0e89f 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -70,6 +70,9 @@ src/remote/remote_driver.c src/rpc/virnetmessage.c src/rpc/virnetsaslcontext.c src/rpc/virnetsocket.c +src/rpc/virnetserver.c +src/rpc/virnetserverclient.c +src/rpc/virnetserverprogram.c src/rpc/virnettlscontext.c src/secret/secret_driver.c src/security/security_apparmor.c diff --git a/src/Makefile.am b/src/Makefile.am index 4907806..2b4a6e4 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1154,7 +1154,7 @@ libvirt_qemu_la_LIBADD = libvirt.la $(CYGWIN_EXTRA_LIBADD) EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) -noinst_LTLIBRARIES += 
libvirt-net-rpc.la +noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ @@ -1181,6 +1181,21 @@ libvirt_net_rpc_la_LDFLAGS = \ libvirt_net_rpc_la_LIBADD = \ $(CYGWIN_EXTRA_LIBADD) +libvirt_net_rpc_server_la_SOURCES = \ + rpc/virnetserverprogram.h rpc/virnetserverprogram.c \ + rpc/virnetserverservice.h rpc/virnetserverservice.c \ + rpc/virnetserverclient.h rpc/virnetserverclient.c \ + rpc/virnetserver.h rpc/virnetserver.c +libvirt_net_rpc_server_la_CFLAGS = \ + $(AM_CFLAGS) +libvirt_net_rpc_server_la_LDFLAGS = \ + $(AM_LDFLAGS) \ + $(CYGWIN_EXTRA_LDFLAGS) \ + $(MINGW_EXTRA_LDFLAGS) +libvirt_net_rpc_server_la_LIBADD = \ + $(CYGWIN_EXTRA_LIBADD) + + libexec_PROGRAMS = if WITH_LIBVIRTD diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c new file mode 100644 index 0000000..b71f34e --- /dev/null +++ b/src/rpc/virnetserver.c @@ -0,0 +1,714 @@ +/* + * virnetserver.c: generic network RPC server + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <unistd.h> +#include <string.h> + +#include "virnetserver.h" +#include "logging.h" +#include "memory.h" +#include "virterror_internal.h" +#include "threads.h" +#include "threadpool.h" +#include "util.h" +#include "files.h" +#include "event.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +typedef struct _virNetServerSignal virNetServerSignal; +typedef virNetServerSignal *virNetServerSignalPtr; + +struct _virNetServerSignal { + struct sigaction oldaction; + int signum; + virNetServerSignalFunc func; + void *opaque; +}; + +typedef struct _virNetServerJob virNetServerJob; +typedef virNetServerJob *virNetServerJobPtr; + +struct _virNetServerJob { + virNetServerClientPtr client; + virNetMessagePtr msg; +}; + +struct _virNetServer { + int refs; + + virMutex lock; + + virThreadPoolPtr workers; + + bool privileged; + + size_t nsignals; + virNetServerSignalPtr *signals; + int sigread; + int sigwrite; + int sigwatch; + + size_t nservices; + virNetServerServicePtr *services; + + size_t nprograms; + virNetServerProgramPtr *programs; + + size_t nclients; + size_t nclients_max; + virNetServerClientPtr *clients; + + unsigned int quit :1; + + virNetTLSContextPtr tls; + + unsigned int autoShutdownTimeout; + virNetServerAutoShutdownFunc autoShutdownFunc; + void *autoShutdownOpaque; + + virNetServerClientInitHook clientInitHook; +}; + + +static void virNetServerLock(virNetServerPtr srv) +{ + virMutexLock(&srv->lock); +} + +static void virNetServerUnlock(virNetServerPtr srv) +{ + virMutexUnlock(&srv->lock); +} + + +static void virNetServerHandleJob(void *jobOpaque, void *opaque) +{ + virNetServerPtr srv = opaque; + virNetServerJobPtr job = jobOpaque; + virNetServerProgramPtr prog = NULL; + size_t i; + + virNetServerClientRef(job->client); + + virNetServerLock(srv); + VIR_DEBUG("server=%p 
client=%p message=%p", + srv, job->client, job->msg); + + for (i = 0 ; i < srv->nprograms ; i++) { + if (virNetServerProgramMatches(srv->programs[i], job->msg)) { + prog = srv->programs[i]; + break; + } + } + + if (!prog) { + VIR_DEBUG("Cannot find program %d version %d", + job->msg->header.prog, + job->msg->header.vers); + goto error; + } + + virNetServerProgramRef(prog); + virNetServerUnlock(srv); + + if (virNetServerProgramDispatch(prog, + srv, + job->client, + job->msg) < 0) + goto error; + + virNetServerLock(srv); + virNetServerProgramFree(prog); + virNetServerUnlock(srv); + virNetServerClientFree(job->client); + + VIR_FREE(job); + return; + +error: + virNetServerUnlock(srv); + virNetServerProgramFree(prog); + virNetMessageFree(job->msg); + virNetServerClientClose(job->client); + virNetServerClientFree(job->client); + VIR_FREE(job); +} + + +static int virNetServerDispatchNewMessage(virNetServerClientPtr client, + virNetMessagePtr msg, + void *opaque) +{ + virNetServerPtr srv = opaque; + virNetServerJobPtr job; + int ret; + + VIR_DEBUG("server=%p client=%p message=%p", + srv, client, msg); + + if (VIR_ALLOC(job) < 0) { + virReportOOMError(); + return -1; + } + + job->client = client; + job->msg = msg; + + virNetServerLock(srv); + if ((ret = virThreadPoolSendJob(srv->workers, job)) < 0) + VIR_FREE(job); + virNetServerUnlock(srv); + + return ret; +} + + +static int virNetServerDispatchNewClient(virNetServerServicePtr svc ATTRIBUTE_UNUSED, + virNetServerClientPtr client, + void *opaque) +{ + virNetServerPtr srv = opaque; + + virNetServerLock(srv); + + if (srv->nclients >= srv->nclients_max) { + virNetError(VIR_ERR_RPC, + _("Too many active clients (%zu), dropping connection from %s"), + srv->nclients_max, virNetServerClientRemoteAddrString(client)); + goto error; + } + + if (virNetServerClientInit(client) < 0) + goto error; + + if (srv->clientInitHook && + srv->clientInitHook(srv, client) < 0) + goto error; + + if (VIR_EXPAND_N(srv->clients, srv->nclients, 1) < 0) 
{ + virReportOOMError(); + goto error; + } + srv->clients[srv->nclients-1] = client; + virNetServerClientRef(client); + + virNetServerClientSetDispatcher(client, + virNetServerDispatchNewMessage, + srv); + + virNetServerUnlock(srv); + return 0; + +error: + virNetServerUnlock(srv); + return -1; +} + + +static void virNetServerFatalSignal(int sig, siginfo_t * siginfo ATTRIBUTE_UNUSED, + void* context ATTRIBUTE_UNUSED) +{ + struct sigaction sig_action; + int origerrno; + + origerrno = errno; + virLogEmergencyDumpAll(sig); + + /* + * If the signal is fatal, avoid looping over this handler + * by desactivating it + */ +#ifdef SIGUSR2 + if (sig != SIGUSR2) { +#endif + sig_action.sa_handler = SIG_IGN; + sigaction(sig, &sig_action, NULL); +#ifdef SIGUSR2 + } +#endif + errno = origerrno; +} + + +virNetServerPtr virNetServerNew(size_t min_workers, + size_t max_workers, + size_t max_clients, + virNetServerClientInitHook clientInitHook) +{ + virNetServerPtr srv; + struct sigaction sig_action; + + if (VIR_ALLOC(srv) < 0) { + virReportOOMError(); + return NULL; + } + + srv->refs = 1; + + if (!(srv->workers = virThreadPoolNew(min_workers, max_workers, + virNetServerHandleJob, + srv))) + goto error; + + srv->nclients_max = max_clients; + srv->sigwrite = srv->sigread = -1; + srv->clientInitHook = clientInitHook; + srv->privileged = geteuid() == 0 ? 
true : false; + + if (virMutexInit(&srv->lock) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("cannot initialize mutex")); + goto error; + } + + if (virEventRegisterDefaultImpl() < 0) + goto error; + + memset(&sig_action, 0, sizeof(sig_action)); + sig_action.sa_handler = SIG_IGN; + sigaction(SIGPIPE, &sig_action, NULL); + + /* + * catch fatal errors to dump a log, also hook to USR2 for dynamic + * debugging purposes or testing + */ + sig_action.sa_sigaction = virNetServerFatalSignal; + sigaction(SIGFPE, &sig_action, NULL); + sigaction(SIGSEGV, &sig_action, NULL); + sigaction(SIGILL, &sig_action, NULL); + sigaction(SIGABRT, &sig_action, NULL); +#ifdef SIGBUS + sigaction(SIGBUS, &sig_action, NULL); +#endif +#ifdef SIGUSR2 + sigaction(SIGUSR2, &sig_action, NULL); +#endif + + VIR_DEBUG("srv=%p refs=%d", srv, srv->refs); + return srv; + +error: + virNetServerFree(srv); + return NULL; +} + + +void virNetServerRef(virNetServerPtr srv) +{ + virNetServerLock(srv); + srv->refs++; + VIR_DEBUG("srv=%p refs=%d", srv, srv->refs); + virNetServerUnlock(srv); +} + + +bool virNetServerIsPrivileged(virNetServerPtr srv) +{ + bool priv; + virNetServerLock(srv); + priv = srv->privileged; + virNetServerUnlock(srv); + return priv; +} + + +void virNetServerAutoShutdown(virNetServerPtr srv, + unsigned int timeout, + virNetServerAutoShutdownFunc func, + void *opaque) +{ + virNetServerLock(srv); + + srv->autoShutdownTimeout = timeout; + srv->autoShutdownFunc = func; + srv->autoShutdownOpaque = opaque; + + virNetServerUnlock(srv); +} + + +static sig_atomic_t sigErrors = 0; +static int sigLastErrno = 0; +static int sigWrite = -1; + +static void virNetServerSignalHandler(int sig, siginfo_t * siginfo, + void* context ATTRIBUTE_UNUSED) +{ + int origerrno; + int r; + + /* set the sig num in the struct */ + siginfo->si_signo = sig; + + origerrno = errno; + r = safewrite(sigWrite, siginfo, sizeof(*siginfo)); + if (r == -1) { + sigErrors++; + sigLastErrno = errno; + } + errno = origerrno; +} + 
+static void +virNetServerSignalEvent(int watch, + int fd ATTRIBUTE_UNUSED, + int events ATTRIBUTE_UNUSED, + void *opaque) { + virNetServerPtr srv = opaque; + siginfo_t siginfo; + int i; + + virNetServerLock(srv); + + if (saferead(srv->sigread, &siginfo, sizeof(siginfo)) != sizeof(siginfo)) { + virReportSystemError(errno, "%s", + _("Failed to read from signal pipe")); + virEventRemoveHandle(watch); + srv->sigwatch = -1; + goto cleanup; + } + + for (i = 0 ; i < srv->nsignals ; i++) { + if (siginfo.si_signo == srv->signals[i]->signum) { + virNetServerSignalFunc func = srv->signals[i]->func; + void *funcopaque = srv->signals[i]->opaque; + virNetServerUnlock(srv); + func(srv, &siginfo, funcopaque); + return; + } + } + + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Unexpected signal received: %d"), siginfo.si_signo); + +cleanup: + virNetServerUnlock(srv); +} + +static int virNetServerSignalSetup(virNetServerPtr srv) +{ + int fds[2]; + + if (srv->sigwrite != -1) + return 0; + + if (pipe(fds) < 0) { + virReportSystemError(errno, "%s", + _("Unable to create signal pipe")); + return -1; + } + + if (virSetNonBlock(fds[0]) < 0 || + virSetNonBlock(fds[1]) < 0 || + virSetCloseExec(fds[0]) < 0 || + virSetCloseExec(fds[1]) < 0) { + virReportSystemError(errno, "%s", + _("Failed to setup pipe flags")); + goto error; + } + + if ((srv->sigwatch = virEventAddHandle(fds[0], + VIR_EVENT_HANDLE_READABLE, + virNetServerSignalEvent, + srv, NULL)) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("Failed to add signal handle watch")); + goto error; + } + + srv->sigread = fds[0]; + srv->sigwrite = fds[1]; + sigWrite = fds[1]; + + return 0; + +error: + VIR_FORCE_CLOSE(fds[0]); + VIR_FORCE_CLOSE(fds[1]); + return -1; +} + +int virNetServerAddSignalHandler(virNetServerPtr srv, + int signum, + virNetServerSignalFunc func, + void *opaque) +{ + virNetServerSignalPtr sigdata; + struct sigaction sig_action; + + virNetServerLock(srv); + + if (virNetServerSignalSetup(srv) < 0) + goto error; + + if 
(VIR_EXPAND_N(srv->signals, srv->nsignals, 1) < 0) + goto no_memory; + + if (VIR_ALLOC(sigdata) < 0) + goto no_memory; + + sigdata->signum = signum; + sigdata->func = func; + sigdata->opaque = opaque; + + memset(&sig_action, 0, sizeof(sig_action)); + sig_action.sa_sigaction = virNetServerSignalHandler; +#ifdef SA_SIGINFO + sig_action.sa_flags = SA_SIGINFO; +#endif + sigemptyset(&sig_action.sa_mask); + + sigaction(signum, &sig_action, &sigdata->oldaction); + + srv->signals[srv->nsignals-1] = sigdata; + + virNetServerUnlock(srv); + return 0; + +no_memory: + virReportOOMError(); +error: + VIR_FREE(sigdata); + virNetServerUnlock(srv); + return -1; +} + + + +int virNetServerAddService(virNetServerPtr srv, + virNetServerServicePtr svc) +{ + virNetServerLock(srv); + + if (VIR_EXPAND_N(srv->services, srv->nservices, 1) < 0) + goto no_memory; + + srv->services[srv->nservices-1] = svc; + virNetServerServiceRef(svc); + + virNetServerServiceSetDispatcher(svc, + virNetServerDispatchNewClient, + srv); + + virNetServerUnlock(srv); + return 0; + +no_memory: + virReportOOMError(); + virNetServerUnlock(srv); + return -1; +} + +int virNetServerAddProgram(virNetServerPtr srv, + virNetServerProgramPtr prog) +{ + virNetServerLock(srv); + + if (VIR_EXPAND_N(srv->programs, srv->nprograms, 1) < 0) + goto no_memory; + + srv->programs[srv->nprograms-1] = prog; + virNetServerProgramRef(prog); + + virNetServerUnlock(srv); + return 0; + +no_memory: + virReportOOMError(); + virNetServerUnlock(srv); + return -1; +} + +int virNetServerSetTLSContext(virNetServerPtr srv, + virNetTLSContextPtr tls) +{ + srv->tls = tls; + virNetTLSContextRef(tls); + return 0; +} + + +static void virNetServerAutoShutdownTimer(int timerid ATTRIBUTE_UNUSED, + void *opaque) { + virNetServerPtr srv = opaque; + + virNetServerLock(srv); + + if (srv->autoShutdownFunc(srv, srv->autoShutdownOpaque)) { + VIR_DEBUG("Automatic shutdown triggered"); + srv->quit = 1; + } + + virNetServerUnlock(srv); +} + + +void 
virNetServerUpdateServices(virNetServerPtr srv, + bool enabled) +{ + int i; + + virNetServerLock(srv); + for (i = 0 ; i < srv->nservices ; i++) + virNetServerServiceToggle(srv->services[i], enabled); + + virNetServerUnlock(srv); +} + + +void virNetServerRun(virNetServerPtr srv) +{ + int timerid = -1; + int timerActive = 0; + int i; + + virNetServerLock(srv); + + if (srv->autoShutdownTimeout && + (timerid = virEventAddTimeout(-1, + virNetServerAutoShutdownTimer, + srv, NULL)) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("Failed to register shutdown timeout")); + goto cleanup; + } + + while (!srv->quit) { + /* A shutdown timeout is specified, so check + * if any drivers have active state, if not + * shutdown after timeout seconds + */ + if (srv->autoShutdownTimeout) { + if (timerActive) { + if (srv->clients) { + VIR_DEBUG("Deactivating shutdown timer %d", timerid); + virEventUpdateTimeout(timerid, -1); + timerActive = 0; + } + } else { + if (!srv->clients) { + VIR_DEBUG("Activating shutdown timer %d", timerid); + virEventUpdateTimeout(timerid, + srv->autoShutdownTimeout * 1000); + timerActive = 1; + } + } + } + + virNetServerUnlock(srv); + if (virEventRunDefaultImpl() < 0) { + virNetServerLock(srv); + VIR_DEBUG("Loop iteration error, exiting"); + break; + } + virNetServerLock(srv); + + reprocess: + for (i = 0 ; i < srv->nclients ; i++) { + if (virNetServerClientWantClose(srv->clients[i])) + virNetServerClientClose(srv->clients[i]); + if (virNetServerClientIsClosed(srv->clients[i])) { + virNetServerClientFree(srv->clients[i]); + if (srv->nclients > 1) { + memmove(srv->clients + i, + srv->clients + i + 1, + sizeof(*srv->clients) * (srv->nclients - (i + 1))); + VIR_SHRINK_N(srv->clients, srv->nclients, 1); + } else { + VIR_FREE(srv->clients); + srv->nclients = 0; + } + + goto reprocess; + } + } + } + +cleanup: + virNetServerUnlock(srv); +} + + +void virNetServerQuit(virNetServerPtr srv) +{ + virNetServerLock(srv); + + srv->quit = 1; + + 
virNetServerUnlock(srv); +} + +void virNetServerFree(virNetServerPtr srv) +{ + int i; + + if (!srv) + return; + + virNetServerLock(srv); + VIR_DEBUG("srv=%p refs=%d", srv, srv->refs); + srv->refs--; + if (srv->refs > 0) { + virNetServerUnlock(srv); + return; + } + + for (i = 0 ; i < srv->nservices ; i++) + virNetServerServiceToggle(srv->services[i], false); + + virThreadPoolFree(srv->workers); + + for (i = 0 ; i < srv->nsignals ; i++) { + sigaction(srv->signals[i]->signum, &srv->signals[i]->oldaction, NULL); + VIR_FREE(srv->signals[i]); + } + VIR_FREE(srv->signals); + VIR_FORCE_CLOSE(srv->sigread); + VIR_FORCE_CLOSE(srv->sigwrite); + if (srv->sigwatch > 0) + virEventRemoveHandle(srv->sigwatch); + + for (i = 0 ; i < srv->nservices ; i++) + virNetServerServiceFree(srv->services[i]); + VIR_FREE(srv->services); + + for (i = 0 ; i < srv->nprograms ; i++) + virNetServerProgramFree(srv->programs[i]); + VIR_FREE(srv->programs); + + for (i = 0 ; i < srv->nclients ; i++) { + virNetServerClientClose(srv->clients[i]); + virNetServerClientFree(srv->clients[i]); + } + VIR_FREE(srv->clients); + + virNetServerUnlock(srv); + virMutexDestroy(&srv->lock); + VIR_FREE(srv); +} diff --git a/src/rpc/virnetserver.h b/src/rpc/virnetserver.h new file mode 100644 index 0000000..d8d7c8e --- /dev/null +++ b/src/rpc/virnetserver.h @@ -0,0 +1,80 @@ +/* + * virnetserver.h: generic network RPC server + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_SERVER_H__ +# define __VIR_NET_SERVER_H__ + +# include <stdbool.h> +# include <signal.h> + +# include "virnettlscontext.h" +# include "virnetserverprogram.h" +# include "virnetserverclient.h" +# include "virnetserverservice.h" + +typedef int (*virNetServerClientInitHook)(virNetServerPtr srv, + virNetServerClientPtr client); + +virNetServerPtr virNetServerNew(size_t min_workers, + size_t max_workers, + size_t max_clients, + virNetServerClientInitHook clientInitHook); + +typedef int (*virNetServerAutoShutdownFunc)(virNetServerPtr srv, void *opaque); + +void virNetServerRef(virNetServerPtr srv); + +bool virNetServerIsPrivileged(virNetServerPtr srv); + +void virNetServerAutoShutdown(virNetServerPtr srv, + unsigned int timeout, + virNetServerAutoShutdownFunc func, + void *opaque); + +typedef void (*virNetServerSignalFunc)(virNetServerPtr srv, siginfo_t *info, void *opaque); + +int virNetServerAddSignalHandler(virNetServerPtr srv, + int signum, + virNetServerSignalFunc func, + void *opaque); + +int virNetServerAddService(virNetServerPtr srv, + virNetServerServicePtr svc); + +int virNetServerAddProgram(virNetServerPtr srv, + virNetServerProgramPtr prog); + +int virNetServerSetTLSContext(virNetServerPtr srv, + virNetTLSContextPtr tls); + +void virNetServerUpdateServices(virNetServerPtr srv, + bool enabled); + +void virNetServerRun(virNetServerPtr srv); + +void virNetServerQuit(virNetServerPtr srv); + +void virNetServerFree(virNetServerPtr srv); + + +#endif diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c new file mode 100644 index 0000000..327b121 --- /dev/null +++ 
b/src/rpc/virnetserverclient.c @@ -0,0 +1,937 @@ +/* + * virnetserverclient.c: generic network RPC server client + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#if HAVE_SASL +# include <sasl/sasl.h> +#endif + +#include "virnetserverclient.h" + +#include "logging.h" +#include "virterror_internal.h" +#include "memory.h" +#include "threads.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +/* Allow for filtering of incoming messages to a custom + * dispatch processing queue, instead of the workers. 
+ * This allows for certain types of messages to be handled + * strictly "in order" + */ + +typedef struct _virNetServerClientFilter virNetServerClientFilter; +typedef virNetServerClientFilter *virNetServerClientFilterPtr; + +struct _virNetServerClientFilter { + int id; + virNetServerClientFilterFunc func; + void *opaque; + + virNetServerClientFilterPtr next; +}; + + +struct _virNetServerClient +{ + int refs; + bool wantClose; + virMutex lock; + virNetSocketPtr sock; + int auth; + bool readonly; + char *identity; + virNetTLSContextPtr tlsCtxt; + virNetTLSSessionPtr tls; +#if HAVE_SASL + virNetSASLSessionPtr sasl; +#endif + + /* Count of messages in the 'tx' queue, + * and the server worker pool queue + * ie RPC calls in progress. Does not count + * async events which are not used for + * throttling calculations */ + size_t nrequests; + size_t nrequests_max; + /* Zero or one messages being received. Zero if + * nrequests >= max_clients and throttling */ + virNetMessagePtr rx; + /* Zero or many messages waiting for transmit + * back to client, including async events */ + virNetMessagePtr tx; + + /* Filters to capture messages that would otherwise + * end up on the 'dx' queue */ + virNetServerClientFilterPtr filters; + int nextFilterID; + + virNetServerClientDispatchFunc dispatchFunc; + void *dispatchOpaque; + + void *privateData; + virNetServerClientFreeFunc privateDataFreeFunc; +}; + + +static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque); +static void virNetServerClientUpdateEvent(virNetServerClientPtr client); + +static void virNetServerClientLock(virNetServerClientPtr client) +{ + virMutexLock(&client->lock); +} + +static void virNetServerClientUnlock(virNetServerClientPtr client) +{ + virMutexUnlock(&client->lock); +} + + +/* + * @client: a locked client object + */ +static int +virNetServerClientCalculateHandleMode(virNetServerClientPtr client) { + int mode = 0; + + + VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p", + client->tls, + 
client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1, + client->rx, + client->tx); + if (!client->sock || client->wantClose) + return 0; + + if (client->tls) { + switch (virNetTLSSessionGetHandshakeStatus(client->tls)) { + case VIR_NET_TLS_HANDSHAKE_RECVING: + mode |= VIR_EVENT_HANDLE_READABLE; + break; + case VIR_NET_TLS_HANDSHAKE_SENDING: + mode |= VIR_EVENT_HANDLE_WRITABLE; + break; + default: + case VIR_NET_TLS_HANDSHAKE_COMPLETE: + if (client->rx) + mode |= VIR_EVENT_HANDLE_READABLE; + if (client->tx) + mode |= VIR_EVENT_HANDLE_WRITABLE; + } + } else { + /* If there is a message on the rx queue then + * we're wanting more input */ + if (client->rx) + mode |= VIR_EVENT_HANDLE_READABLE; + + /* If there are one or more messages to send back to client, + then monitor for writability on socket */ + if (client->tx) + mode |= VIR_EVENT_HANDLE_WRITABLE; + } + VIR_DEBUG("mode=%d", mode); + return mode; +} + +/* + * @server: a locked or unlocked server object + * @client: a locked client object + */ +static int virNetServerClientRegisterEvent(virNetServerClientPtr client) +{ + int mode = virNetServerClientCalculateHandleMode(client); + + VIR_DEBUG("Registering client event callback %d", mode); + if (virNetSocketAddIOCallback(client->sock, + mode, + virNetServerClientDispatchEvent, + client) < 0) + return -1; + + return 0; +} + +/* + * @client: a locked client object + */ +static void virNetServerClientUpdateEvent(virNetServerClientPtr client) +{ + int mode; + + if (!client->sock) + return; + + mode = virNetServerClientCalculateHandleMode(client); + + virNetSocketUpdateIOCallback(client->sock, mode); +} + + +int virNetServerClientAddFilter(virNetServerClientPtr client, + virNetServerClientFilterFunc func, + void *opaque) +{ + virNetServerClientFilterPtr filter; + int ret = -1; + + virNetServerClientLock(client); + + if (VIR_ALLOC(filter) < 0) { + virReportOOMError(); + goto cleanup; + } + + filter->id = client->nextFilterID++; + filter->func = func; + 
filter->opaque = opaque; + + filter->next = client->filters; + client->filters = filter; + + ret = filter->id; + +cleanup: + virNetServerClientUnlock(client); + return ret; +} + + +void virNetServerClientRemoveFilter(virNetServerClientPtr client, + int filterID) +{ + virNetServerClientFilterPtr tmp, prev; + virNetServerClientLock(client); + + prev = NULL; + tmp = client->filters; + while (tmp) { + if (tmp->id == filterID) { + if (prev) + prev->next = tmp->next; + else + client->filters = tmp->next; + + VIR_FREE(tmp); + break; + } + tmp = tmp->next; + } + + virNetServerClientUnlock(client); +} + + +/* Check the client's access. */ +static int +virNetServerClientCheckAccess(virNetServerClientPtr client) +{ + virNetMessagePtr confirm; + + /* Verify client certificate. */ + if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0) + return -1; + + if (client->tx) { + VIR_DEBUG("client had unexpected data pending tx after access check"); + return -1; + } + + if (!(confirm = virNetMessageNew())) + return -1; + + /* Checks have succeeded. Write a '\1' byte back to the client to + * indicate this (otherwise the socket is abruptly closed). + * (NB. The '\1' byte is sent in an encrypted record). 
+     */
+    confirm->bufferLength = 1;
+    confirm->bufferOffset = 0;
+    confirm->buffer[0] = '\1';
+
+    client->tx = confirm;
+
+    return 0;
+}
+
+
+/*
+ * @sock: socket the client is connected on
+ * @auth: authentication scheme value stored on the client
+ * @readonly: whether the client is flagged read-only
+ * @tls: optional TLS context; a reference is taken if non-NULL
+ *
+ * Allocate a client holding a single reference, with one message
+ * buffer prepared to receive the first packet.
+ *
+ * Returns the new client, or NULL on failure. On failure the
+ * caller still owns @sock.
+ */
+virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
+                                            int auth,
+                                            bool readonly,
+                                            virNetTLSContextPtr tls)
+{
+    virNetServerClientPtr client;
+
+    VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth, tls);
+
+    if (VIR_ALLOC(client) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    if (virMutexInit(&client->lock) < 0) {
+        /* The lock is unusable, so virNetServerClientFree(), which
+         * locks and then destroys it, must not be used for cleanup.
+         * Nothing else has been set up yet: just release the struct */
+        VIR_FREE(client);
+        return NULL;
+    }
+
+    client->refs = 1;
+    client->sock = sock;
+    client->auth = auth;
+    client->readonly = readonly;
+    client->tlsCtxt = tls;
+    client->nrequests_max = 10; /* XXX */
+
+    if (tls)
+        virNetTLSContextRef(tls);
+
+    /* Prepare one for packet receive */
+    if (!(client->rx = virNetMessageNew()))
+        goto error;
+    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+    client->nrequests = 1;
+
+    VIR_DEBUG("client=%p refs=%d", client, client->refs);
+
+    return client;
+
+error:
+    /* XXX ref counting is better than this */
+    client->sock = NULL; /* Caller owns 'sock' upon failure */
+    virNetServerClientFree(client);
+    return NULL;
+}
+
+/* Take an additional reference on @client (an unlocked client object) */
+void virNetServerClientRef(virNetServerClientPtr client)
+{
+    virNetServerClientLock(client);
+    client->refs++;
+    VIR_DEBUG("client=%p refs=%d", client, client->refs);
+    virNetServerClientUnlock(client);
+}
+
+
+/* Return the 'auth' value the client was created with */
+int virNetServerClientGetAuth(virNetServerClientPtr client)
+{
+    int auth;
+    virNetServerClientLock(client);
+    auth = client->auth;
+    virNetServerClientUnlock(client);
+    return auth;
+}
+
+/* Return the 'readonly' flag the client was created with */
+bool virNetServerClientGetReadonly(virNetServerClientPtr client)
+{
+    bool readonly;
+    virNetServerClientLock(client);
+    readonly = client->readonly;
+    virNetServerClientUnlock(client);
+    return readonly;
+}
+
+
+/* Report whether a TLS session is currently attached to the client */
+bool virNetServerClientHasTLSSession(virNetServerClientPtr client)
+{
+    bool has;
+    virNetServerClientLock(client);
+    has = client->tls ?
true : false; + virNetServerClientUnlock(client); + return has; +} + +int virNetServerClientGetTLSKeySize(virNetServerClientPtr client) +{ + int size = 0; + virNetServerClientLock(client); + if (client->tls) + size = virNetTLSSessionGetKeySize(client->tls); + virNetServerClientUnlock(client); + return size; +} + +int virNetServerClientGetFD(virNetServerClientPtr client) +{ + int fd = 0; + virNetServerClientLock(client); + fd = virNetSocketGetFD(client->sock); + virNetServerClientUnlock(client); + return fd; +} + +int virNetServerClientGetLocalIdentity(virNetServerClientPtr client, + uid_t *uid, pid_t *pid) +{ + int ret; + virNetServerClientLock(client); + ret = virNetSocketGetLocalIdentity(client->sock, uid, pid); + virNetServerClientUnlock(client); + return ret; +} + +bool virNetServerClientIsSecure(virNetServerClientPtr client) +{ + bool secure = false; + virNetServerClientLock(client); + if (client->tls) + secure = true; +#if HAVE_SASL + if (client->sasl) + secure = true; +#endif + if (virNetSocketIsLocal(client->sock)) + secure = true; + virNetServerClientUnlock(client); + return secure; +} + + +#if HAVE_SASL +void virNetServerClientSetSASLSession(virNetServerClientPtr client, + virNetSASLSessionPtr sasl) +{ + /* We don't set the sasl session on the socket here + * because we need to send out the auth confirmation + * in the clear. 
Only once we complete the next 'tx' + * operation do we switch to SASL mode + */ + virNetServerClientLock(client); + client->sasl = sasl; + virNetSASLSessionRef(sasl); + virNetServerClientUnlock(client); +} +#endif + + +int virNetServerClientSetIdentity(virNetServerClientPtr client, + const char *identity) +{ + int ret = -1; + virNetServerClientLock(client); + if (!(client->identity = strdup(identity))) { + virReportOOMError(); + goto error; + } + ret = 0; + +error: + virNetServerClientUnlock(client); + return ret; +} + +const char *virNetServerClientGetIdentity(virNetServerClientPtr client) +{ + const char *identity; + virNetServerClientLock(client); + identity = client->identity; + virNetServerClientLock(client); + return identity; +} + +void virNetServerClientSetPrivateData(virNetServerClientPtr client, + void *opaque, + virNetServerClientFreeFunc ff) +{ + virNetServerClientLock(client); + + if (client->privateData && + client->privateDataFreeFunc) + client->privateDataFreeFunc(client->privateData); + + client->privateData = opaque; + client->privateDataFreeFunc = ff; + + virNetServerClientUnlock(client); +} + + +void *virNetServerClientGetPrivateData(virNetServerClientPtr client) +{ + void *data; + virNetServerClientLock(client); + data = client->privateData; + virNetServerClientUnlock(client); + return data; +} + + +void virNetServerClientSetDispatcher(virNetServerClientPtr client, + virNetServerClientDispatchFunc func, + void *opaque) +{ + virNetServerClientLock(client); + client->dispatchFunc = func; + client->dispatchOpaque = opaque; + virNetServerClientUnlock(client); +} + + +const char *virNetServerClientLocalAddrString(virNetServerClientPtr client) +{ + return virNetSocketLocalAddrString(client->sock); +} + + +const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client) +{ + return virNetSocketRemoteAddrString(client->sock); +} + + +void virNetServerClientFree(virNetServerClientPtr client) +{ + if (!client) + return; + + 
virNetServerClientLock(client); + VIR_DEBUG("client=%p refs=%d", client, client->refs); + + client->refs--; + if (client->refs > 0) { + virNetServerClientUnlock(client); + return; + } + + if (client->privateData && + client->privateDataFreeFunc) + client->privateDataFreeFunc(client->privateData); + + VIR_FREE(client->identity); +#if HAVE_SASL + virNetSASLSessionFree(client->sasl); +#endif + virNetTLSSessionFree(client->tls); + virNetTLSContextFree(client->tlsCtxt); + virNetSocketFree(client->sock); + virNetServerClientUnlock(client); + virMutexDestroy(&client->lock); + VIR_FREE(client); +} + + +/* + * + * We don't free stuff here, merely disconnect the client's + * network socket & resources. + * + * Full free of the client is done later in a safe point + * where it can be guaranteed it is no longer in use + */ +void virNetServerClientClose(virNetServerClientPtr client) +{ + virNetServerClientLock(client); + VIR_DEBUG("client=%p refs=%d", client, client->refs); + if (!client->sock) { + virNetServerClientUnlock(client); + return; + } + + /* Do now, even though we don't close the socket + * until end, to ensure we don't get invoked + * again due to tls shutdown */ + if (client->sock) + virNetSocketRemoveIOCallback(client->sock); + + if (client->tls) { + virNetTLSSessionFree(client->tls); + client->tls = NULL; + } + if (client->sock) { + virNetSocketFree(client->sock); + client->sock = NULL; + } + + while (client->rx) { + virNetMessagePtr msg + = virNetMessageQueueServe(&client->rx); + virNetMessageFree(msg); + } + while (client->tx) { + virNetMessagePtr msg + = virNetMessageQueueServe(&client->tx); + virNetMessageFree(msg); + } + + virNetServerClientUnlock(client); +} + + +bool virNetServerClientIsClosed(virNetServerClientPtr client) +{ + bool closed; + virNetServerClientLock(client); + closed = client->sock == NULL ? 
true : false; + virNetServerClientUnlock(client); + return closed; +} + +void virNetServerClientMarkClose(virNetServerClientPtr client) +{ + virNetServerClientLock(client); + client->wantClose = true; + virNetServerClientUnlock(client); +} + +bool virNetServerClientWantClose(virNetServerClientPtr client) +{ + bool wantClose; + virNetServerClientLock(client); + wantClose = client->wantClose; + virNetServerClientUnlock(client); + return wantClose; +} + + +int virNetServerClientInit(virNetServerClientPtr client) +{ + virNetServerClientLock(client); + + if (!client->tlsCtxt) { + /* Plain socket, so prepare to read first message */ + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else { + int ret; + + if (!(client->tls = virNetTLSSessionNew(client->tlsCtxt, + NULL))) + goto error; + + virNetSocketSetTLSSession(client->sock, + client->tls); + + /* Begin the TLS handshake. */ + ret = virNetTLSSessionHandshake(client->tls); + if (ret == 0) { + /* Unlikely, but ... Next step is to check the certificate. 
*/ + if (virNetServerClientCheckAccess(client) < 0) + goto error; + + /* Handshake & cert check OK, so prepare to read first message */ + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else if (ret > 0) { + /* Most likely, need to do more handshake data */ + if (virNetServerClientRegisterEvent(client) < 0) + goto error; + } else { + goto error; + } + } + + virNetServerClientUnlock(client); + return 0; + +error: + client->wantClose = true; + virNetServerClientUnlock(client); + return -1; +} + + + +/* + * Read data into buffer using wire decoding (plain or TLS) + * + * Returns: + * -1 on error or EOF + * 0 on EAGAIN + * n number of bytes + */ +static ssize_t virNetServerClientRead(virNetServerClientPtr client) +{ + ssize_t ret; + + if (client->rx->bufferLength <= client->rx->bufferOffset) { + virNetError(VIR_ERR_RPC, + _("unexpected zero/negative length request %lld"), + (long long int)(client->rx->bufferLength - client->rx->bufferOffset)); + client->wantClose = true; + return -1; + } + + ret = virNetSocketRead(client->sock, + client->rx->buffer + client->rx->bufferOffset, + client->rx->bufferLength - client->rx->bufferOffset); + + if (ret <= 0) + return ret; + + client->rx->bufferOffset += ret; + return ret; +} + + +/* + * Read data until we get a complete message to process + */ +static void virNetServerClientDispatchRead(virNetServerClientPtr client) +{ +readmore: + if (virNetServerClientRead(client) < 0) { + client->wantClose = true; + return; /* Error */ + } + + if (client->rx->bufferOffset < client->rx->bufferLength) + return; /* Still not read enough */ + + /* Either done with length word header */ + if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) { + if (virNetMessageDecodeLength(client->rx) < 0) + return; + + virNetServerClientUpdateEvent(client); + + /* Try and read payload immediately instead of going back + into poll() because chances are the data is already + waiting for us */ + goto readmore; + } else { + /* Grab the 
completed message */ + virNetMessagePtr msg = virNetMessageQueueServe(&client->rx); + virNetServerClientFilterPtr filter; + + /* Decode the header so we can use it for routing decisions */ + if (virNetMessageDecodeHeader(msg) < 0) { + virNetMessageFree(msg); + client->wantClose = true; + return; + } + + /* Maybe send off for queue against a filter */ + filter = client->filters; + while (filter) { + int ret = filter->func(client, msg, filter->opaque); + if (ret < 0 || ret > 0) { + virNetMessageFree(msg); + msg = NULL; + if (ret < 0) + client->wantClose = true; + break; + } + + filter = filter->next; + } + + /* Send off to for normal dispatch to workers */ + if (msg) { + if (!client->dispatchFunc || + client->dispatchFunc(client, msg, client->dispatchOpaque) < 0) { + virNetMessageFree(msg); + client->wantClose = true; + return; + } + } + + /* Possibly need to create another receive buffer */ + if (client->nrequests < client->nrequests_max) { + if (!(client->rx = virNetMessageNew())) { + client->wantClose = true; + } + client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; + client->nrequests++; + } + virNetServerClientUpdateEvent(client); + } +} + + +/* + * Send client->tx using no encoding + * + * Returns: + * -1 on error or EOF + * 0 on EAGAIN + * n number of bytes + */ +static ssize_t virNetServerClientWrite(virNetServerClientPtr client) +{ + ssize_t ret; + + if (client->tx->bufferLength < client->tx->bufferOffset) { + virNetError(VIR_ERR_RPC, + _("unexpected zero/negative length request %lld"), + (long long int)(client->tx->bufferLength - client->tx->bufferOffset)); + client->wantClose = true; + return -1; + } + + if (client->tx->bufferLength == client->tx->bufferOffset) + return 1; + + ret = virNetSocketWrite(client->sock, + client->tx->buffer + client->tx->bufferOffset, + client->tx->bufferLength - client->tx->bufferOffset); + if (ret <= 0) + return ret; /* -1 error, 0 = egain */ + + client->tx->bufferOffset += ret; + return ret; +} + + +/* + * Process all queued 
client->tx messages until + * we would block on I/O + */ +static void +virNetServerClientDispatchWrite(virNetServerClientPtr client) +{ + while (client->tx) { + ssize_t ret; + + ret = virNetServerClientWrite(client); + if (ret < 0) { + client->wantClose = true; + return; + } + if (ret == 0) + return; /* Would block on write EAGAIN */ + + if (client->tx->bufferOffset == client->tx->bufferLength) { + virNetMessagePtr msg; +#if HAVE_SASL + /* Completed this 'tx' operation, so now read for all + * future rx/tx to be under a SASL SSF layer + */ + if (client->sasl) { + virNetSocketSetSASLSession(client->sock, client->sasl); + virNetSASLSessionFree(client->sasl); + client->sasl = NULL; + } +#endif + + /* Get finished msg from head of tx queue */ + msg = virNetMessageQueueServe(&client->tx); + + if (msg->header.type == VIR_NET_REPLY) { + client->nrequests--; + /* See if the recv queue is currently throttled */ + if (!client->rx && + client->nrequests < client->nrequests_max) { + /* Ready to recv more messages */ + client->rx = msg; + client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; + msg = NULL; + client->nrequests++; + } + } + + virNetMessageFree(msg); + + virNetServerClientUpdateEvent(client); + } + } +} + +static void +virNetServerClientDispatchHandshake(virNetServerClientPtr client) +{ + int ret; + /* Continue the handshake. */ + ret = virNetTLSSessionHandshake(client->tls); + if (ret == 0) { + /* Finished. Next step is to check the certificate. */ + if (virNetServerClientCheckAccess(client) < 0) + client->wantClose = true; + else + virNetServerClientUpdateEvent(client); + } else if (ret > 0) { + /* Carry on waiting for more handshake. 
Update + the events just in case handshake data flow + direction has changed */ + virNetServerClientUpdateEvent (client); + } else { + /* Fatal error in handshake */ + client->wantClose = true; + } +} + +static void +virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque) +{ + virNetServerClientPtr client = opaque; + + virNetServerClientLock(client); + + if (client->sock != sock) { + virNetSocketRemoveIOCallback(sock); + virNetServerClientUnlock(client); + return; + } + + if (events & (VIR_EVENT_HANDLE_WRITABLE | + VIR_EVENT_HANDLE_READABLE)) { + if (client->tls && + virNetTLSSessionGetHandshakeStatus(client->tls) != + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + virNetServerClientDispatchHandshake(client); + } else { + if (events & VIR_EVENT_HANDLE_WRITABLE) + virNetServerClientDispatchWrite(client); + if (events & VIR_EVENT_HANDLE_READABLE) + virNetServerClientDispatchRead(client); + } + } + + /* NB, will get HANGUP + READABLE at same time upon + * disconnect */ + if (events & (VIR_EVENT_HANDLE_ERROR | + VIR_EVENT_HANDLE_HANGUP)) + client->wantClose = true; + + virNetServerClientUnlock(client); +} + + +int virNetServerClientSendMessage(virNetServerClientPtr client, + virNetMessagePtr msg) +{ + int ret = -1; + VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu", + msg, msg->header.proc, + msg->bufferLength, msg->bufferOffset); + virNetServerClientLock(client); + + if (client->sock && !client->wantClose) { + virNetMessageQueuePush(&client->tx, msg); + + virNetServerClientUpdateEvent(client); + ret = 0; + } + + virNetServerClientUnlock(client); + return ret; +} + + +bool virNetServerClientNeedAuth(virNetServerClientPtr client) +{ + bool need = false; + virNetServerClientLock(client); + if (client->auth && !client->identity) + need = true; + virNetServerClientUnlock(client); + return need; +} diff --git a/src/rpc/virnetserverclient.h b/src/rpc/virnetserverclient.h new file mode 100644 index 0000000..8554590 --- /dev/null +++ b/src/rpc/virnetserverclient.h @@ 
-0,0 +1,106 @@ +/* + * virnetserverclient.h: generic network RPC server client + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_SERVER_CLIENT_H__ +# define __VIR_NET_SERVER_CLIENT_H__ + +# include "virnetsocket.h" +# include "virnetmessage.h" + +typedef struct _virNetServerClient virNetServerClient; +typedef virNetServerClient *virNetServerClientPtr; + +typedef int (*virNetServerClientDispatchFunc)(virNetServerClientPtr client, + virNetMessagePtr msg, + void *opaque); + +typedef int (*virNetServerClientFilterFunc)(virNetServerClientPtr client, + virNetMessagePtr msg, + void *opaque); + +virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock, + int auth, + bool readonly, + virNetTLSContextPtr tls); + +int virNetServerClientAddFilter(virNetServerClientPtr client, + virNetServerClientFilterFunc func, + void *opaque); + +void virNetServerClientRemoveFilter(virNetServerClientPtr client, + int filterID); + +int virNetServerClientGetAuth(virNetServerClientPtr client); +bool virNetServerClientGetReadonly(virNetServerClientPtr client); + +bool virNetServerClientHasTLSSession(virNetServerClientPtr client); 
+int virNetServerClientGetTLSKeySize(virNetServerClientPtr client); + +#ifdef HAVE_SASL +void virNetServerClientSetSASLSession(virNetServerClientPtr client, + virNetSASLSessionPtr sasl); +#endif + +int virNetServerClientGetFD(virNetServerClientPtr client); + +bool virNetServerClientIsSecure(virNetServerClientPtr client); + +int virNetServerClientSetIdentity(virNetServerClientPtr client, + const char *identity); +const char *virNetServerClientGetIdentity(virNetServerClientPtr client); + +int virNetServerClientGetLocalIdentity(virNetServerClientPtr client, + uid_t *uid, pid_t *pid); + +void virNetServerClientRef(virNetServerClientPtr client); + +typedef void (*virNetServerClientFreeFunc)(void *data); + +void virNetServerClientSetPrivateData(virNetServerClientPtr client, + void *opaque, + virNetServerClientFreeFunc ff); +void *virNetServerClientGetPrivateData(virNetServerClientPtr client); + +void virNetServerClientSetDispatcher(virNetServerClientPtr client, + virNetServerClientDispatchFunc func, + void *opaque); +void virNetServerClientClose(virNetServerClientPtr client); + +bool virNetServerClientIsClosed(virNetServerClientPtr client); +void virNetServerClientMarkClose(virNetServerClientPtr client); +bool virNetServerClientWantClose(virNetServerClientPtr client); + +int virNetServerClientInit(virNetServerClientPtr client); + +const char *virNetServerClientLocalAddrString(virNetServerClientPtr client); +const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client); + +int virNetServerClientSendMessage(virNetServerClientPtr client, + virNetMessagePtr msg); + +bool virNetServerClientNeedAuth(virNetServerClientPtr client); + +void virNetServerClientFree(virNetServerClientPtr client); + + +#endif /* __VIR_NET_SERVER_CLIENT_H__ */ diff --git a/src/rpc/virnetserverprogram.c b/src/rpc/virnetserverprogram.c new file mode 100644 index 0000000..0d1577a --- /dev/null +++ b/src/rpc/virnetserverprogram.c @@ -0,0 +1,455 @@ +/* + * virnetserverprogram.c: generic 
network RPC server program + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include "virnetserverprogram.h" +#include "virnetserverclient.h" + +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) 
\ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetServerProgram { + int refs; + + unsigned program; + unsigned version; + virNetServerProgramProcPtr procs; + size_t nprocs; +}; + +virNetServerProgramPtr virNetServerProgramNew(unsigned program, + unsigned version, + virNetServerProgramProcPtr procs, + size_t nprocs) +{ + virNetServerProgramPtr prog; + + if (VIR_ALLOC(prog) < 0) { + virReportOOMError(); + return NULL; + } + + prog->refs = 1; + prog->program = program; + prog->version = version; + prog->procs = procs; + prog->nprocs = nprocs; + + VIR_DEBUG("prog=%p refs=%d", prog, prog->refs); + + return prog; +} + + +int virNetServerProgramGetID(virNetServerProgramPtr prog) +{ + return prog->program; +} + + +int virNetServerProgramGetVersion(virNetServerProgramPtr prog) +{ + return prog->version; +} + + +void virNetServerProgramRef(virNetServerProgramPtr prog) +{ + prog->refs++; + VIR_DEBUG("prog=%p refs=%d", prog, prog->refs); +} + + +int virNetServerProgramMatches(virNetServerProgramPtr prog, + virNetMessagePtr msg) +{ + if (prog->program == msg->header.prog && + prog->version == msg->header.vers) + return 1; + return 0; +} + + +static virNetServerProgramProcPtr virNetServerProgramGetProc(virNetServerProgramPtr prog, + int procedure) +{ + if (procedure < 0) + return NULL; + if (procedure >= prog->nprocs) + return NULL; + + return &prog->procs[procedure]; +} + + +static int +virNetServerProgramSendError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + int procedure, + int type, + int serial) +{ + VIR_DEBUG("prog=%d ver=%d proc=%d type=%d serial=%d msg=%p rerr=%p", + prog->program, prog->version, procedure, type, serial, msg, rerr); + + virNetMessageSaveError(rerr); + + /* Return header. 
*/ + msg->header.prog = prog->program; + msg->header.vers = prog->version; + msg->header.proc = procedure; + msg->header.type = type; + msg->header.serial = serial; + msg->header.status = VIR_NET_ERROR; + + if (virNetMessageEncodeHeader(msg) < 0) + goto error; + + if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetMessageError, rerr) < 0) + goto error; + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr); + + /* Put reply on end of tx queue to send out */ + if (virNetServerClientSendMessage(client, msg) < 0) + return -1; + + return 0; + +error: + VIR_WARN("Failed to serialize remote error '%p'", rerr); + xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr); + return -1; +} + + +/* + * @client: the client to send the error to + * @req: the message this error is in reply to + * + * Send an error message to the client + * + * Returns 0 if the error was sent, -1 upon fatal error + */ +int +virNetServerProgramSendReplyError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + virNetMessageHeaderPtr req) +{ + /* + * For data streams, errors are sent back as data streams + * For method calls, errors are sent back as method replies + */ + return virNetServerProgramSendError(prog, + client, + msg, + rerr, + req->proc, + req->type == VIR_NET_STREAM ? 
VIR_NET_STREAM : VIR_NET_REPLY, + req->serial); +} + + +int virNetServerProgramSendStreamError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + int procedure, + int serial) +{ + return virNetServerProgramSendError(prog, + client, + msg, + rerr, + procedure, + VIR_NET_STREAM, + serial); +} + + +static int +virNetServerProgramDispatchCall(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg); + +/* + * @server: the unlocked server object + * @client: the unlocked client object + * @msg: the complete incoming message packet, with header already decoded + * + * This function is intended to be called from worker threads + * when an incoming message is ready to be dispatched for + * execution. + * + * Upon successful return the '@msg' instance will be released + * by this function (or more often, reused to send a reply). + * Upon failure, the '@msg' must be freed by the caller. + * + * Returns 0 if the message was dispatched, -1 upon fatal error + */ +int virNetServerProgramDispatch(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg) +{ + int ret = -1; + virNetMessageError rerr; + + memset(&rerr, 0, sizeof(rerr)); + + VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%d proc=%d", + msg->header.prog, msg->header.vers, msg->header.type, + msg->header.status, msg->header.serial, msg->header.proc); + + /* Check version, etc. 
*/ + if (msg->header.prog != prog->program) { + virNetError(VIR_ERR_RPC, + _("program mismatch (actual %x, expected %x)"), + msg->header.prog, prog->program); + goto error; + } + + if (msg->header.vers != prog->version) { + virNetError(VIR_ERR_RPC, + _("version mismatch (actual %x, expected %x)"), + msg->header.vers, prog->version); + goto error; + } + + switch (msg->header.type) { + case VIR_NET_CALL: + ret = virNetServerProgramDispatchCall(prog, server, client, msg); + break; + + case VIR_NET_STREAM: + /* Since stream data is non-acked, async, we may continue to receive + * stream packets after we closed down a stream. Just drop & ignore + * these. + */ + VIR_INFO("Ignoring unexpected stream data serial=%d proc=%d status=%d", + msg->header.serial, msg->header.proc, msg->header.status); + virNetMessageFree(msg); + ret = 0; + break; + + default: + virNetError(VIR_ERR_RPC, + _("Unexpected message type %u"), + msg->header.type); + goto error; + } + + return ret; + +error: + ret = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header); + + return ret; +} + + +/* + * @server: the unlocked server object + * @client: the unlocked client object + * @msg: the complete incoming method call, with header already decoded + * + * This method is used to dispatch an message representing an + * incoming method call from a client. 
It decodes the payload + * to obtain method call arguments, invokves the method and + * then sends a reply packet with the return values + * + * Returns 0 if the reply was sent, or -1 upon fatal error + */ +static int +virNetServerProgramDispatchCall(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg) +{ + char *arg = NULL; + char *ret = NULL; + int rv = -1; + virNetServerProgramProcPtr dispatcher; + virNetMessageError rerr; + + memset(&rerr, 0, sizeof(rerr)); + + if (msg->header.status != VIR_NET_OK) { + virNetError(VIR_ERR_RPC, + _("Unexpected message status %u"), + msg->header.status); + goto error; + } + + dispatcher = virNetServerProgramGetProc(prog, msg->header.proc); + + if (!dispatcher) { + virNetError(VIR_ERR_RPC, + _("unknown procedure: %d"), + msg->header.proc); + goto error; + } + + /* If client is marked as needing auth, don't allow any RPC ops + * which are except for authentication ones + */ + if (virNetServerClientNeedAuth(client) && + dispatcher->needAuth) { + /* Explicitly *NOT* calling remoteDispatchAuthError() because + we want back-compatability with libvirt clients which don't + support the VIR_ERR_AUTH_FAILED error code */ + virNetError(VIR_ERR_RPC, + "%s", _("authentication required")); + goto error; + } + + if (VIR_ALLOC_N(arg, dispatcher->arg_len) < 0) { + virReportOOMError(); + goto error; + } + if (VIR_ALLOC_N(ret, dispatcher->ret_len) < 0) { + virReportOOMError(); + goto error; + } + + if (virNetMessageDecodePayload(msg, dispatcher->arg_filter, arg) < 0) + goto error; + + /* + * When the RPC handler is called: + * + * - Server object is unlocked + * - Client object is unlocked + * + * Without locking, it is safe to use: + * + * 'args and 'ret' + */ + rv = (dispatcher->func)(server, client, &msg->header, &rerr, arg, ret); + + xdr_free(dispatcher->arg_filter, arg); + + if (rv < 0) + goto error; + + /* Return header. 
We're re-using same message object, so + * only need to tweak type/status fields */ + /*msg->header.prog = msg->header.prog;*/ + /*msg->header.vers = msg->header.vers;*/ + /*msg->header.proc = msg->header.proc;*/ + msg->header.type = VIR_NET_REPLY; + /*msg->header.serial = msg->header.serial;*/ + msg->header.status = VIR_NET_OK; + + if (virNetMessageEncodeHeader(msg) < 0) { + xdr_free(dispatcher->ret_filter, ret); + goto error; + } + + if (virNetMessageEncodePayload(msg, dispatcher->ret_filter, ret) < 0) { + xdr_free(dispatcher->ret_filter, ret); + goto error; + } + + xdr_free(dispatcher->ret_filter, ret); + VIR_FREE(arg); + VIR_FREE(ret); + + /* Put reply on end of tx queue to send out */ + return virNetServerClientSendMessage(client, msg); + +error: + /* Bad stuff (de-)serializing message, but we have an + * RPC error message we can send back to the client */ + rv = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header); + + VIR_FREE(arg); + VIR_FREE(ret); + + return rv; +} + + +int virNetServerProgramSendStreamData(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + int procedure, + int serial, + const char *data, + size_t len) +{ + VIR_DEBUG("client=%p msg=%p data=%p len=%zu", client, msg, data, len); + + /* Return header. We're reusing same message object, so + * only need to tweak type/status fields */ + msg->header.prog = prog->program; + msg->header.vers = prog->version; + msg->header.proc = procedure; + msg->header.type = VIR_NET_STREAM; + msg->header.serial = serial; + /* + * NB + * data != NULL + len > 0 => REMOTE_CONTINUE (Sending back data) + * data != NULL + len == 0 => REMOTE_CONTINUE (Sending read EOF) + * data == NULL => REMOTE_OK (Sending finish handshake confirmation) + */ + msg->header.status = data ? 
VIR_NET_CONTINUE : VIR_NET_OK; + + if (virNetMessageEncodeHeader(msg) < 0) + return -1; + + if (data && len) { + if (virNetMessageEncodePayloadRaw(msg, data, len) < 0) + return -1; + + VIR_DEBUG("Total %zu", msg->bufferOffset); + } + + return virNetServerClientSendMessage(client, msg); +} + + +void virNetServerProgramFree(virNetServerProgramPtr prog) +{ + if (!prog) + return; + + VIR_DEBUG("prog=%p refs=%d", prog, prog->refs); + + prog->refs--; + if (prog->refs > 0) + return; + + VIR_FREE(prog); +} diff --git a/src/rpc/virnetserverprogram.h b/src/rpc/virnetserverprogram.h new file mode 100644 index 0000000..b68a3ef --- /dev/null +++ b/src/rpc/virnetserverprogram.h @@ -0,0 +1,107 @@ +/* + * virnetserverprogram.h: generic network RPC server program + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_PROGRAM_H__ +# define __VIR_NET_PROGRAM_H__ + +# include <stdbool.h> + +# include "virnetmessage.h" +# include "virnetserverclient.h" + +typedef struct _virNetServer virNetServer; +typedef virNetServer *virNetServerPtr; + +typedef struct _virNetServerService virNetServerService; +typedef virNetServerService *virNetServerServicePtr; + +typedef struct _virNetServerProgram virNetServerProgram; +typedef virNetServerProgram *virNetServerProgramPtr; + +typedef struct _virNetServerProgramProc virNetServerProgramProc; +typedef virNetServerProgramProc *virNetServerProgramProcPtr; + +typedef struct _virNetServerProgramErrorHandler virNetServerProgramErrorHander; +typedef virNetServerProgramErrorHander *virNetServerProgramErrorHanderPtr; + +typedef int (*virNetServerProgramDispatchFunc)(virNetServerPtr server, + virNetServerClientPtr client, + virNetMessageHeaderPtr hdr, + virNetMessageErrorPtr rerr, + void *args, + void *ret); + +struct _virNetServerProgramProc { + virNetServerProgramDispatchFunc func; + size_t arg_len; + xdrproc_t arg_filter; + size_t ret_len; + xdrproc_t ret_filter; + bool needAuth; +}; + +virNetServerProgramPtr virNetServerProgramNew(unsigned program, + unsigned version, + virNetServerProgramProcPtr procs, + size_t nprocs); + +int virNetServerProgramGetID(virNetServerProgramPtr prog); +int virNetServerProgramGetVersion(virNetServerProgramPtr prog); + +void virNetServerProgramRef(virNetServerProgramPtr prog); + +int virNetServerProgramMatches(virNetServerProgramPtr prog, + virNetMessagePtr msg); + +int virNetServerProgramDispatch(virNetServerProgramPtr prog, + virNetServerPtr server, + virNetServerClientPtr client, + virNetMessagePtr msg); + +int virNetServerProgramSendReplyError(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + virNetMessageHeaderPtr req); + +int virNetServerProgramSendStreamError(virNetServerProgramPtr prog, + 
virNetServerClientPtr client, + virNetMessagePtr msg, + virNetMessageErrorPtr rerr, + int procedure, + int serial); + +int virNetServerProgramSendStreamData(virNetServerProgramPtr prog, + virNetServerClientPtr client, + virNetMessagePtr msg, + int procedure, + int serial, + const char *data, + size_t len); + +void virNetServerProgramFree(virNetServerProgramPtr prog); + + + + +#endif /* __VIR_NET_SERVER_PROGRAM_H__ */ diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c new file mode 100644 index 0000000..0cc65c3 --- /dev/null +++ b/src/rpc/virnetserverservice.c @@ -0,0 +1,247 @@ +/* + * virnetserverservice.c: generic network RPC server service + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include "virnetserverservice.h" + +#include "memory.h" +#include "virterror_internal.h" + + +#define VIR_FROM_THIS VIR_FROM_RPC + +struct _virNetServerService { + int refs; + + size_t nsocks; + virNetSocketPtr *socks; + + int auth; + bool readonly; + + virNetTLSContextPtr tls; + + virNetServerServiceDispatchFunc dispatchFunc; + void *dispatchOpaque; +}; + + + +static void virNetServerServiceAccept(virNetSocketPtr sock, + int events ATTRIBUTE_UNUSED, + void *opaque) +{ + virNetServerServicePtr svc = opaque; + virNetServerClientPtr client = NULL; + virNetSocketPtr clientsock = NULL; + + if (virNetSocketAccept(sock, &clientsock) < 0) + goto error; + + if (!clientsock) /* Connection already went away */ + goto cleanup; + + if (!(client = virNetServerClientNew(clientsock, + svc->auth, + svc->readonly, + svc->tls))) + goto error; + + if (!svc->dispatchFunc) + goto error; + + if (svc->dispatchFunc(svc, client, svc->dispatchOpaque) < 0) + virNetServerClientClose(client); + + virNetServerClientFree(client); + +cleanup: + return; + +error: + virNetSocketFree(clientsock); +} + + +virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename, + const char *service, + int auth, + bool readonly, + virNetTLSContextPtr tls) +{ + virNetServerServicePtr svc; + size_t i; + + if (VIR_ALLOC(svc) < 0) + goto no_memory; + + svc->refs = 1; + svc->auth = auth; + svc->readonly = readonly; + svc->tls = tls; + if (tls) + virNetTLSContextRef(tls); + + if (virNetSocketNewListenTCP(nodename, + service, + &svc->socks, + &svc->nsocks) < 0) + goto error; + + for (i = 0 ; i < svc->nsocks ; i++) { + if (virNetSocketListen(svc->socks[i]) < 0) + goto error; + + /* IO callback is initially disabled, until we're ready + * to deal with incoming clients */ + if (virNetSocketAddIOCallback(svc->socks[i], + 0, + virNetServerServiceAccept, + svc) < 0) + goto error; + } + + + return svc; + +no_memory: + virReportOOMError(); +error: + 
virNetServerServiceFree(svc); + return NULL; +} + + +virNetServerServicePtr virNetServerServiceNewUNIX(const char *path, + mode_t mask, + gid_t grp, + int auth, + bool readonly, + virNetTLSContextPtr tls) +{ + virNetServerServicePtr svc; + int i; + + if (VIR_ALLOC(svc) < 0) + goto no_memory; + + svc->refs = 1; + svc->auth = auth; + svc->readonly = readonly; + svc->tls = tls; + if (tls) + virNetTLSContextRef(tls); + + svc->nsocks = 1; + if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0) + goto no_memory; + + if (virNetSocketNewListenUNIX(path, + mask, + grp, + &svc->socks[0]) < 0) + goto error; + + for (i = 0 ; i < svc->nsocks ; i++) { + if (virNetSocketListen(svc->socks[i]) < 0) + goto error; + + /* IO callback is initially disabled, until we're ready + * to deal with incoming clients */ + if (virNetSocketAddIOCallback(svc->socks[i], + 0, + virNetServerServiceAccept, + svc) < 0) + goto error; + } + + + return svc; + +no_memory: + virReportOOMError(); +error: + virNetServerServiceFree(svc); + return NULL; +} + + +int virNetServerServiceGetAuth(virNetServerServicePtr svc) +{ + return svc->auth; +} + + +bool virNetServerServiceIsReadonly(virNetServerServicePtr svc) +{ + return svc->readonly; +} + + +void virNetServerServiceRef(virNetServerServicePtr svc) +{ + svc->refs++; +} + + +void virNetServerServiceSetDispatcher(virNetServerServicePtr svc, + virNetServerServiceDispatchFunc func, + void *opaque) +{ + svc->dispatchFunc = func; + svc->dispatchOpaque = opaque; +} + + +void virNetServerServiceFree(virNetServerServicePtr svc) +{ + int i; + + if (!svc) + return; + + svc->refs--; + if (svc->refs > 0) + return; + + for (i = 0 ; i < svc->nsocks ; i++) + virNetSocketFree(svc->socks[i]); + VIR_FREE(svc->socks); + + virNetTLSContextFree(svc->tls); + + VIR_FREE(svc); +} + +void virNetServerServiceToggle(virNetServerServicePtr svc, + bool enabled) +{ + int i; + + for (i = 0 ; i < svc->nsocks ; i++) + virNetSocketUpdateIOCallback(svc->socks[i], + enabled ? 
+ VIR_EVENT_HANDLE_READABLE : + 0); +} diff --git a/src/rpc/virnetserverservice.h b/src/rpc/virnetserverservice.h new file mode 100644 index 0000000..b8ccd55 --- /dev/null +++ b/src/rpc/virnetserverservice.h @@ -0,0 +1,65 @@ +/* + * virnetserverservice.h: generic network RPC server service + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * Copyright (C) 2006 Daniel P. Berrange + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_SERVER_SERVICE_H__ +# define __VIR_NET_SERVER_SERVICE_H__ + +# include "virnetserverprogram.h" + +enum { + VIR_NET_SERVER_SERVICE_AUTH_NONE = 0, + VIR_NET_SERVER_SERVICE_AUTH_SASL, + VIR_NET_SERVER_SERVICE_AUTH_POLKIT, +}; + +typedef int (*virNetServerServiceDispatchFunc)(virNetServerServicePtr svc, + virNetServerClientPtr client, + void *opaque); + +virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename, + const char *service, + int auth, + bool readonly, + virNetTLSContextPtr tls); +virNetServerServicePtr virNetServerServiceNewUNIX(const char *path, + mode_t mask, + gid_t grp, + int auth, + bool readonly, + virNetTLSContextPtr tls); + +int virNetServerServiceGetAuth(virNetServerServicePtr svc); +bool virNetServerServiceIsReadonly(virNetServerServicePtr svc); + +void virNetServerServiceRef(virNetServerServicePtr svc); + +void virNetServerServiceSetDispatcher(virNetServerServicePtr svc, + virNetServerServiceDispatchFunc func, + void *opaque); + +void virNetServerServiceFree(virNetServerServicePtr svc); + +void virNetServerServiceToggle(virNetServerServicePtr svc, + bool enabled); + +#endif -- 1.7.4.4

Allow RPC servers to advertise themselves using MDNS, via Avahi * src/rpc/virnetserver.c, src/rpc/virnetserver.h: Allow registration of MDNS services via avahi * src/rpc/virnetserverservice.c, src/rpc/virnetserverservice.h: Add API to fetch the listen port number * src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Add API to fetch the local port number * src/rpc/virnetservermdns.c, src/rpc/virnetservermdns.h: Represent an MDNS advertisement --- cfg.mk | 3 + src/Makefile.am | 9 + src/rpc/virnetserver.c | 47 +++- src/rpc/virnetserver.h | 4 +- src/rpc/virnetservermdns.c | 616 +++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetservermdns.h | 109 ++++++++ src/rpc/virnetserverservice.c | 8 + src/rpc/virnetserverservice.h | 2 + src/rpc/virnetsocket.c | 6 + src/rpc/virnetsocket.h | 2 + 10 files changed, 804 insertions(+), 2 deletions(-) create mode 100644 src/rpc/virnetservermdns.c create mode 100644 src/rpc/virnetservermdns.h diff --git a/cfg.mk b/cfg.mk index 1cd2a0f..87be171 100644 --- a/cfg.mk +++ b/cfg.mk @@ -128,6 +128,9 @@ useless_free_options = \ --name=virNetMessageFree \ --name=virNetServerFree \ --name=virNetServerClientFree \ + --name=virNetServerMDNSFree \ + --name=virNetServerMDNSEntryFree \ + --name=virNetServerMDNSGroupFree \ --name=virNetServerProgramFree \ --name=virNetServerServiceFree \ --name=virNetSocketFree \ diff --git a/src/Makefile.am b/src/Makefile.am index 2b4a6e4..298759e 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1186,10 +1186,19 @@ libvirt_net_rpc_server_la_SOURCES = \ rpc/virnetserverservice.h rpc/virnetserverservice.c \ rpc/virnetserverclient.h rpc/virnetserverclient.c \ rpc/virnetserver.h rpc/virnetserver.c +if HAVE_AVAHI +libvirt_net_rpc_server_la_SOURCES += \ + rpc/virnetservermdns.h rpc/virnetservermdns.c +else +EXTRA_DIST += \ + rpc/virnetservermdns.h rpc/virnetservermdns.c +endif libvirt_net_rpc_server_la_CFLAGS = \ + $(AVAHI_CFLAGS) \ $(AM_CFLAGS) libvirt_net_rpc_server_la_LDFLAGS = \ $(AM_LDFLAGS) \ + $(AVAHI_LIBS) \ 
$(CYGWIN_EXTRA_LDFLAGS) \ $(MINGW_EXTRA_LDFLAGS)l libvirt_net_rpc_server_la_LIBADD = \ diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c index b71f34e..5055f40 100644 --- a/src/rpc/virnetserver.c +++ b/src/rpc/virnetserver.c @@ -35,6 +35,9 @@ #include "util.h" #include "files.h" #include "event.h" +#if HAVE_AVAHI +#include "virnetservermdns.h" +#endif #define VIR_FROM_THIS VIR_FROM_RPC #define virNetError(code, ...) \ @@ -74,6 +77,12 @@ struct _virNetServer { int sigwrite; int sigwatch; + char *mdnsGroupName; +#if HAVE_AVAHI + virNetServerMDNSPtr mdns; + virNetServerMDNSGroupPtr mdnsGroup; +#endif + size_t nservices; virNetServerServicePtr *services; @@ -259,6 +268,7 @@ static void virNetServerFatalSignal(int sig, siginfo_t * siginfo ATTRIBUTE_UNUSE virNetServerPtr virNetServerNew(size_t min_workers, size_t max_workers, size_t max_clients, + const char *mdnsGroupName, virNetServerClientInitHook clientInitHook) { virNetServerPtr srv; @@ -281,6 +291,19 @@ virNetServerPtr virNetServerNew(size_t min_workers, srv->clientInitHook = clientInitHook; srv->privileged = geteuid() == 0 ? 
true : false; + if (!(srv->mdnsGroupName = strdup(mdnsGroupName))) { + virReportOOMError(); + goto error; + } +#if HAVE_AVAHI + if (srv->mdnsGroupName) { + if (!(srv->mdns = virNetServerMDNSNew())) + goto error; + if (!(srv->mdnsGroup = virNetServerMDNSAddGroup(srv->mdns, mdnsGroupName))) + goto error; + } +#endif + if (virMutexInit(&srv->lock) < 0) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("cannot initialize mutex")); @@ -502,13 +525,26 @@ error: int virNetServerAddService(virNetServerPtr srv, - virNetServerServicePtr svc) + virNetServerServicePtr svc, + const char *mdnsEntryName ATTRIBUTE_UNUSED) { virNetServerLock(srv); if (VIR_EXPAND_N(srv->services, srv->nservices, 1) < 0) goto no_memory; +#if HAVE_AVAHI + if (mdnsEntryName) { + int port = virNetServerServiceGetPort(svc); + virNetServerMDNSEntryPtr entry; + + if (!(entry = virNetServerMDNSAddEntry(srv->mdnsGroup, + mdnsEntryName, + port))) + goto error; + } +#endif + srv->services[srv->nservices-1] = svc; virNetServerServiceRef(svc); @@ -521,6 +557,9 @@ int virNetServerAddService(virNetServerPtr srv, no_memory: virReportOOMError(); +#if HAVE_AVAHI +error: +#endif virNetServerUnlock(srv); return -1; } @@ -590,6 +629,12 @@ void virNetServerRun(virNetServerPtr srv) virNetServerLock(srv); +#if HAVE_AVAHI + if (srv->mdns && + virNetServerMDNSStart(srv->mdns) < 0) + goto cleanup; +#endif + if (srv->autoShutdownTimeout && (timerid = virEventAddTimeout(-1, virNetServerAutoShutdownTimer, diff --git a/src/rpc/virnetserver.h b/src/rpc/virnetserver.h index d8d7c8e..ed09ecf 100644 --- a/src/rpc/virnetserver.h +++ b/src/rpc/virnetserver.h @@ -38,6 +38,7 @@ typedef int (*virNetServerClientInitHook)(virNetServerPtr srv, virNetServerPtr virNetServerNew(size_t min_workers, size_t max_workers, size_t max_clients, + const char *mdnsGroupName, virNetServerClientInitHook clientInitHook); typedef int (*virNetServerAutoShutdownFunc)(virNetServerPtr srv, void *opaque); @@ -59,7 +60,8 @@ int 
virNetServerAddSignalHandler(virNetServerPtr srv, void *opaque); int virNetServerAddService(virNetServerPtr srv, - virNetServerServicePtr svc); + virNetServerServicePtr svc, + const char *mdnsEntryName); int virNetServerAddProgram(virNetServerPtr srv, virNetServerProgramPtr prog); diff --git a/src/rpc/virnetservermdns.c b/src/rpc/virnetservermdns.c new file mode 100644 index 0000000..23958d2 --- /dev/null +++ b/src/rpc/virnetservermdns.c @@ -0,0 +1,616 @@ +/* + * virnetservermdns.c: advertise server sockets + * + * Copyright (C) 2011 Red Hat, Inc. + * Copyright (C) 2007 Daniel P. Berrange + * + * Derived from Avahi example service provider code. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <time.h> +#include <stdio.h> +#include <stdlib.h> + +#include <avahi-client/client.h> +#include <avahi-client/publish.h> + +#include <avahi-common/alternative.h> +#include <avahi-common/simple-watch.h> +#include <avahi-common/malloc.h> +#include <avahi-common/error.h> +#include <avahi-common/timeval.h> + +#include "virnetservermdns.h" +#include "event.h" +#include "event_poll.h" +#include "memory.h" +#include "virterror_internal.h" +#include "logging.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +struct _virNetServerMDNSEntry { + char *type; + int port; + virNetServerMDNSEntryPtr next; +}; + +struct _virNetServerMDNSGroup { + virNetServerMDNSPtr mdns; + AvahiEntryGroup *handle; + char *name; + virNetServerMDNSEntryPtr entry; + virNetServerMDNSGroupPtr next; +}; + +struct _virNetServerMDNS { + AvahiClient *client; + AvahiPoll *poller; + virNetServerMDNSGroupPtr group; +}; + +/* Avahi API requires this struct name in the app :-( */ +struct AvahiWatch { + int watch; + int fd; + int revents; + AvahiWatchCallback callback; + void *userdata; +}; + +/* Avahi API requires this struct name in the app :-( */ +struct AvahiTimeout { + int timer; + AvahiTimeoutCallback callback; + void *userdata; +}; + +static void virNetServerMDNSCreateServices(virNetServerMDNSGroupPtr group); + +/* Called whenever the entry group state changes */ +static void virNetServerMDNSGroupCallback(AvahiEntryGroup *g ATTRIBUTE_UNUSED, + AvahiEntryGroupState state, + void *data) +{ + virNetServerMDNSGroupPtr group = data; + + switch (state) { + case AVAHI_ENTRY_GROUP_ESTABLISHED: + /* The entry group has been established successfully */ + VIR_DEBUG("Group '%s' established", group->name); + break; + + case AVAHI_ENTRY_GROUP_COLLISION: + { + char *n; + + /* A service name collision happened. 
Let's pick a new name */ + n = avahi_alternative_service_name(group->name); + VIR_FREE(group->name); + group->name = n; + + VIR_DEBUG("Group name collision, renaming service to '%s'", group->name); + + /* And recreate the services */ + virNetServerMDNSCreateServices(group); + } + break; + + case AVAHI_ENTRY_GROUP_FAILURE : + VIR_DEBUG("Group failure: %s", + avahi_strerror(avahi_client_errno(group->mdns->client))); + + /* Some kind of failure happened while we were registering our services */ + //avahi_simple_poll_quit(simple_poll); + break; + + case AVAHI_ENTRY_GROUP_UNCOMMITED: + case AVAHI_ENTRY_GROUP_REGISTERING: + ; + } +} + +static void virNetServerMDNSCreateServices(virNetServerMDNSGroupPtr group) +{ + virNetServerMDNSPtr mdns = group->mdns; + virNetServerMDNSEntryPtr entry; + int ret; + VIR_DEBUG("Adding services to '%s'", group->name); + + /* If we've no services to advertise, just reset the group to make + * sure it is emptied of any previously advertised services */ + if (!group->entry) { + if (group->handle) + avahi_entry_group_reset(group->handle); + return; + } + + /* If this is the first time we're called, let's create a new entry group */ + if (!group->handle) { + VIR_DEBUG("Creating initial group %s", group->name); + if (!(group->handle = + avahi_entry_group_new(mdns->client, + virNetServerMDNSGroupCallback, + group))) { + VIR_DEBUG("avahi_entry_group_new() failed: %s", + avahi_strerror(avahi_client_errno(mdns->client))); + return; + } + } + + entry = group->entry; + while (entry) { + if ((ret = avahi_entry_group_add_service(group->handle, + AVAHI_IF_UNSPEC, + AVAHI_PROTO_UNSPEC, + 0, + group->name, + entry->type, + NULL, + NULL, + entry->port, + NULL)) < 0) { + VIR_DEBUG("Failed to add %s service on port %d: %s", + entry->type, entry->port, avahi_strerror(ret)); + avahi_entry_group_reset(group->handle); + return; + } + entry = entry->next; + } + + /* Tell the server to register the service */ + if ((ret = avahi_entry_group_commit(group->handle)) < 
0) { + avahi_entry_group_reset(group->handle); + VIR_DEBUG("Failed to commit entry_group: %s", + avahi_strerror(ret)); + return; + } +} + + +static void virNetServerMDNSClientCallback(AvahiClient *c, + AvahiClientState state, + void *data) +{ + virNetServerMDNSPtr mdns = data; + virNetServerMDNSGroupPtr group; + if (!mdns->client) + mdns->client = c; + + VIR_DEBUG("Callback state=%d", state); + + /* Called whenever the client or server state changes */ + switch (state) { + case AVAHI_CLIENT_S_RUNNING: + /* The server has startup successfully and registered its host + * name on the network, so it's time to create our services */ + VIR_DEBUG("Client running %p", mdns->client); + group = mdns->group; + while (group) { + virNetServerMDNSCreateServices(group); + group = group->next; + } + break; + + case AVAHI_CLIENT_FAILURE: + VIR_DEBUG("Client failure: %s", + avahi_strerror(avahi_client_errno(c))); + virNetServerMDNSStop(mdns); + virNetServerMDNSStart(mdns); + break; + + case AVAHI_CLIENT_S_COLLISION: + /* Let's drop our registered services. When the server is back + * in AVAHI_SERVER_RUNNING state we will register them + * again with the new host name. */ + + /* Fallthrough */ + + case AVAHI_CLIENT_S_REGISTERING: + /* The server records are now being established. This + * might be caused by a host name change. We need to wait + * for our own records to register until the host name is + * properly established. */ + VIR_DEBUG("Client collision/connecting %p", mdns->client); + group = mdns->group; + while (group) { + if (group->handle) + avahi_entry_group_reset(group->handle); + group = group->next; + } + break; + + case AVAHI_CLIENT_CONNECTING: + VIR_DEBUG("Client connecting.... 
%p", mdns->client); + ; + } +} + + +static void virNetServerMDNSWatchDispatch(int watch, int fd, int events, void *opaque) +{ + AvahiWatch *w = opaque; + int fd_events = virEventPollToNativeEvents(events); + VIR_DEBUG("Dispatch watch %d FD %d Event %d", watch, fd, fd_events); + w->revents = fd_events; + w->callback(w, fd, fd_events, w->userdata); +} + +static void virNetServerMDNSWatchDofree(void *w) +{ + VIR_FREE(w); +} + + +static AvahiWatch *virNetServerMDNSWatchNew(const AvahiPoll *api ATTRIBUTE_UNUSED, + int fd, AvahiWatchEvent event, + AvahiWatchCallback cb, void *userdata) +{ + AvahiWatch *w; + virEventHandleType hEvents; + if (VIR_ALLOC(w) < 0) { + virReportOOMError(); + return NULL; + } + + w->fd = fd; + w->revents = 0; + w->callback = cb; + w->userdata = userdata; + + VIR_DEBUG("New handle %p FD %d Event %d", w, w->fd, event); + hEvents = virEventPollFromNativeEvents(event); + if ((w->watch = virEventAddHandle(fd, hEvents, + virNetServerMDNSWatchDispatch, + w, + virNetServerMDNSWatchDofree)) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Failed to add watch for fd %d events %d"), fd, hEvents); + VIR_FREE(w); + return NULL; + } + + return w; +} + +static void virNetServerMDNSWatchUpdate(AvahiWatch *w, AvahiWatchEvent event) +{ + VIR_DEBUG("Update handle %p FD %d Event %d", w, w->fd, event); + virEventUpdateHandle(w->watch, event); +} + +static AvahiWatchEvent virNetServerMDNSWatchGetEvents(AvahiWatch *w) +{ + VIR_DEBUG("Get handle events %p %d", w, w->fd); + return w->revents; +} + +static void virNetServerMDNSWatchFree(AvahiWatch *w) +{ + VIR_DEBUG("Free handle %p %d", w, w->fd); + virEventRemoveHandle(w->watch); +} + +static void virNetServerMDNSTimeoutDispatch(int timer ATTRIBUTE_UNUSED, void *opaque) +{ + AvahiTimeout *t = (AvahiTimeout*)opaque; + VIR_DEBUG("Dispatch timeout %p %d", t, timer); + virEventUpdateTimeout(t->timer, -1); + t->callback(t, t->userdata); +} + +static void virNetServerMDNSTimeoutDofree(void *t) +{ + VIR_FREE(t); +} + +static 
AvahiTimeout *virNetServerMDNSTimeoutNew(const AvahiPoll *api ATTRIBUTE_UNUSED, + const struct timeval *tv, + AvahiTimeoutCallback cb, + void *userdata) +{ + AvahiTimeout *t; + struct timeval now; + long long nowms, thenms, timeout; + VIR_DEBUG("Add timeout TV %p", tv); + if (VIR_ALLOC(t) < 0) { + virReportOOMError(); + return NULL; + } + + if (gettimeofday(&now, NULL) < 0) { + virReportSystemError(errno, "%s", + _("Unable to get current time")); + VIR_FREE(t); + return NULL; + } + + VIR_DEBUG("Trigger timed for %d %d %d %d", + (int)now.tv_sec, (int)now.tv_usec, + (int)(tv ? tv->tv_sec : 0), (int)(tv ? tv->tv_usec : 0)); + nowms = (now.tv_sec * 1000ll) + (now.tv_usec / 1000ll); + if (tv) { + thenms = (tv->tv_sec * 1000ll) + (tv->tv_usec/1000ll); + timeout = thenms > nowms ? nowms - thenms : 0; + if (timeout < 0) + timeout = 0; + } else { + timeout = -1; + } + + t->timer = virEventAddTimeout(timeout, + virNetServerMDNSTimeoutDispatch, + t, + virNetServerMDNSTimeoutDofree); + t->callback = cb; + t->userdata = userdata; + + if (t->timer < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Failed to add timer with timeout %d"), (int)timeout); + VIR_FREE(t); + return NULL; + } + + return t; +} + +static void virNetServerMDNSTimeoutUpdate(AvahiTimeout *t, const struct timeval *tv) +{ + struct timeval now; + long long nowms, thenms, timeout; + VIR_DEBUG("Update timeout %p TV %p", t, tv); + if (gettimeofday(&now, NULL) < 0) { + VIR_FREE(t); + return; + } + + nowms = (now.tv_sec * 1000ll) + (now.tv_usec / 1000ll); + if (tv) { + thenms = ((tv->tv_sec * 1000ll) + (tv->tv_usec/1000ll)); + timeout = thenms > nowms ? 
nowms - thenms : 0; + if (timeout < 0) + timeout = 0; + } else { + timeout = -1; + } + + virEventUpdateTimeout(t->timer, timeout); +} + +static void virNetServerMDNSTimeoutFree(AvahiTimeout *t) +{ + VIR_DEBUG("Free timeout %p", t); + virEventRemoveTimeout(t->timer); +} + + +static AvahiPoll *virNetServerMDNSCreatePoll(void) +{ + AvahiPoll *p; + if (VIR_ALLOC(p) < 0) { + virReportOOMError(); + return NULL; + } + + p->userdata = NULL; + + p->watch_new = virNetServerMDNSWatchNew; + p->watch_update = virNetServerMDNSWatchUpdate; + p->watch_get_events = virNetServerMDNSWatchGetEvents; + p->watch_free = virNetServerMDNSWatchFree; + + p->timeout_new = virNetServerMDNSTimeoutNew; + p->timeout_update = virNetServerMDNSTimeoutUpdate; + p->timeout_free = virNetServerMDNSTimeoutFree; + + return p; +} + + +virNetServerMDNS *virNetServerMDNSNew(void) +{ + virNetServerMDNS *mdns; + if (VIR_ALLOC(mdns) < 0) + return NULL; + + /* Allocate main loop object */ + if (!(mdns->poller = virNetServerMDNSCreatePoll())) { + VIR_FREE(mdns); + return NULL; + } + + return mdns; +} + + +int virNetServerMDNSStart(virNetServerMDNS *mdns) +{ + int error; + VIR_DEBUG("Starting client %p", mdns); + mdns->client = avahi_client_new(mdns->poller, + AVAHI_CLIENT_NO_FAIL, + virNetServerMDNSClientCallback, + mdns, &error); + + if (!mdns->client) { + virNetError(VIR_ERR_INTERNAL_ERROR, + _("Failed to create mDNS client: %s"), + avahi_strerror(error)); + return -1; + } + + return 0; +} + + +virNetServerMDNSGroupPtr virNetServerMDNSAddGroup(virNetServerMDNS *mdns, + const char *name) +{ + virNetServerMDNSGroupPtr group; + + VIR_DEBUG("Adding group '%s'", name); + if (VIR_ALLOC(group) < 0) { + virReportOOMError(); + return NULL; + } + + if (!(group->name = strdup(name))) { + VIR_FREE(group); + virReportOOMError(); + return NULL; + } + group->mdns = mdns; + group->next = mdns->group; + mdns->group = group; + return group; +} + + +void virNetServerMDNSRemoveGroup(virNetServerMDNSPtr mdns, + 
virNetServerMDNSGroupPtr group) +{ + virNetServerMDNSGroupPtr tmp = mdns->group, prev = NULL; + + while (tmp) { + if (tmp == group) { + VIR_FREE(group->name); + if (prev) + prev->next = group->next; + else + group->mdns->group = group->next; + VIR_FREE(group); + return; + } + prev = tmp; + tmp = tmp->next; + } +} + + +virNetServerMDNSEntryPtr virNetServerMDNSAddEntry(virNetServerMDNSGroupPtr group, + const char *type, + int port) +{ + virNetServerMDNSEntryPtr entry; + + VIR_DEBUG("Adding entry %s %d to group %s", type, port, group->name); + if (VIR_ALLOC(entry) < 0) { + virReportOOMError(); + return NULL; + } + + entry->port = port; + if (!(entry->type = strdup(type))) { + VIR_FREE(entry); + virReportOOMError(); + return NULL; + } + entry->next = group->entry; + group->entry = entry; + return entry; +} + + +void virNetServerMDNSRemoveEntry(virNetServerMDNSGroupPtr group, + virNetServerMDNSEntryPtr entry) +{ + virNetServerMDNSEntryPtr tmp = group->entry, prev = NULL; + + while (tmp) { + if (tmp == entry) { + VIR_FREE(entry->type); + if (prev) + prev->next = entry->next; + else + group->entry = entry->next; + return; + } + prev = tmp; + tmp = tmp->next; + } +} + + +void virNetServerMDNSStop(virNetServerMDNSPtr mdns) +{ + virNetServerMDNSGroupPtr group = mdns->group; + while (group) { + if (group->handle) { + avahi_entry_group_free(group->handle); + group->handle = NULL; + } + group = group->next; + } + if (mdns->client) + avahi_client_free(mdns->client); + mdns->client = NULL; +} + + +void virNetServerMDNSFree(virNetServerMDNSPtr mdns) +{ + virNetServerMDNSGroupPtr group, tmp; + + if (!mdns) + return; + + group = mdns->group; + while (group) { + tmp = group->next; + virNetServerMDNSGroupFree(group); + group = tmp; + } + + VIR_FREE(mdns); +} + + +void virNetServerMDNSGroupFree(virNetServerMDNSGroupPtr grp) +{ + virNetServerMDNSEntryPtr entry, tmp; + + if (!grp) + return; + + entry = grp->entry; + while (entry) { + tmp = entry->next; + virNetServerMDNSEntryFree(entry); 
+ entry = tmp; + } + + /* The group owns its name string (RemoveGroup releases it too), + * so drop it here to avoid leaking it on teardown */ + VIR_FREE(grp->name); + VIR_FREE(grp); +} + + +/* Releases a service entry along with the type string duplicated + * for it by virNetServerMDNSAddEntry(). NULL is a no-op. */ +void virNetServerMDNSEntryFree(virNetServerMDNSEntryPtr entry) +{ + if (!entry) + return; + + VIR_FREE(entry->type); + VIR_FREE(entry); +} + + diff --git a/src/rpc/virnetservermdns.h b/src/rpc/virnetservermdns.h new file mode 100644 index 0000000..6261aef --- /dev/null +++ b/src/rpc/virnetservermdns.h @@ -0,0 +1,109 @@ +/* + * virnetservermdns.h: advertise server sockets + * + * Copyright (C) 2011 Red Hat, Inc. + * Copyright (C) 2007 Daniel P. Berrange + * + * Derived from Avahi example service provider code. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. 
Berrange <berrange@redhat.com> + */ + +#ifndef __VIR_NET_SERVER_MDNS_H__ +# define __VIR_NET_SERVER_MDNS_H__ + +#include "internal.h" + +typedef struct _virNetServerMDNS virNetServerMDNS; +typedef virNetServerMDNS *virNetServerMDNSPtr; +typedef struct _virNetServerMDNSGroup virNetServerMDNSGroup; +typedef virNetServerMDNSGroup *virNetServerMDNSGroupPtr; +typedef struct _virNetServerMDNSEntry virNetServerMDNSEntry; +typedef virNetServerMDNSEntry *virNetServerMDNSEntryPtr; + + +/** + * Prepares a new mdns manager object for use + */ +virNetServerMDNSPtr virNetServerMDNSNew(void); + +/** + * Starts the mdns client, advertising any groups/entries currently registered + * + * @mdns: manager to start advertising + * + * Starts the mdns client. Services may not be immediately visible, since + * it may asynchronously wait for the mdns service to start up + * + * returns -1 upon failure, 0 upon success. + */ +int virNetServerMDNSStart(virNetServerMDNSPtr mdns); + +/** + * Stops the mdns client, removing any advertisements + * + * @mdns: manager to stop advertising + * + */ +void virNetServerMDNSStop(virNetServerMDNSPtr mdns); + +/** + * Adds a group container for advertisement + * + * @mdns manager to attach the group to + * @name unique human readable service name + * + * returns the group record, or NULL upon failure + */ +virNetServerMDNSGroupPtr virNetServerMDNSAddGroup(virNetServerMDNSPtr mdns, + const char *name); + +/** + * Removes a group container from advertisement + * + * @mdns manager to detach group from + * @group group to remove + */ +void virNetServerMDNSRemoveGroup(virNetServerMDNSPtr mdns, + virNetServerMDNSGroupPtr group); + +/** + * Adds a service entry in a group + * + * @group group to attach the entry to + * @type service type string + * @port tcp port number + * + * returns the service record, or NULL upon failure + */ +virNetServerMDNSEntryPtr virNetServerMDNSAddEntry(virNetServerMDNSGroupPtr group, + const char *type, int port); + +/** + * Removes 
a service entry from a group + * + * @group group to detach service entry from + * @entry service entry to remove + */ +void virNetServerMDNSRemoveEntry(virNetServerMDNSGroupPtr group, + virNetServerMDNSEntryPtr entry); + +void virNetServerMDNSFree(virNetServerMDNSPtr ptr); +void virNetServerMDNSGroupFree(virNetServerMDNSGroupPtr ptr); +void virNetServerMDNSEntryFree(virNetServerMDNSEntryPtr ptr); + +#endif /* __VIR_NET_SERVER_MDNS_H__ */ diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c index 0cc65c3..e5a47b0 100644 --- a/src/rpc/virnetserverservice.c +++ b/src/rpc/virnetserverservice.c @@ -187,6 +187,14 @@ error: } +int virNetServerServiceGetPort(virNetServerServicePtr svc) +{ + /* We're assuming if there are multiple sockets + * for IPv4 & 6, then they are all on same port */ + return virNetSocketGetPort(svc->socks[0]); +} + + int virNetServerServiceGetAuth(virNetServerServicePtr svc) { return svc->auth; diff --git a/src/rpc/virnetserverservice.h b/src/rpc/virnetserverservice.h index b8ccd55..378fa0b 100644 --- a/src/rpc/virnetserverservice.h +++ b/src/rpc/virnetserverservice.h @@ -48,6 +48,8 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path, bool readonly, virNetTLSContextPtr tls); +int virNetServerServiceGetPort(virNetServerServicePtr svc); + int virNetServerServiceGetAuth(virNetServerServicePtr svc); bool virNetServerServiceIsReadonly(virNetServerServicePtr svc); diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index 60ee4e5..06855da 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -694,6 +694,12 @@ bool virNetSocketIsLocal(virNetSocketPtr sock) } +int virNetSocketGetPort(virNetSocketPtr sock) +{ + return virSocketGetPort(&sock->localAddr); +} + + #ifdef SO_PEERCRED int virNetSocketGetLocalIdentity(virNetSocketPtr sock, uid_t *uid, diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index 59ff288..356d6c6 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ 
-77,6 +77,8 @@ int virNetSocketNewConnectExternal(const char **cmdargv, int virNetSocketGetFD(virNetSocketPtr sock); bool virNetSocketIsLocal(virNetSocketPtr sock); +int virNetSocketGetPort(virNetSocketPtr sock); + int virNetSocketGetLocalIdentity(virNetSocketPtr sock, uid_t *uid, pid_t *pid); -- 1.7.4.4

To facilitate creation of new clients using XDR RPC services, pull a lot of the remote driver code into a set of reusable objects. - virNetClient: Encapsulates a socket connection to a remote RPC server. Handles all the network I/O for reading/writing RPC messages. Delegates RPC encoding and decoding to the registered programs - virNetClientProgram: Handles processing and dispatch of RPC messages for a single RPC (program,version). A program can register to receive async events from a client - virNetClientStream: Handles generic I/O stream integration to RPC layer Each new client program now merely needs to define the list of RPC procedures & events it wants and their handlers. It does not need to deal with any of the network I/O functionality at all. --- cfg.mk | 3 + po/POTFILES.in | 2 + src/Makefile.am | 14 +- src/rpc/virnetclient.c | 1165 +++++++++++++++++++++++++++++++++++++++++ src/rpc/virnetclient.h | 86 +++ src/rpc/virnetclientprogram.c | 339 ++++++++++++ src/rpc/virnetclientprogram.h | 85 +++ src/rpc/virnetclientstream.c | 442 ++++++++++++++++ src/rpc/virnetclientstream.h | 76 +++ 9 files changed, 2211 insertions(+), 1 deletions(-) create mode 100644 src/rpc/virnetclient.c create mode 100644 src/rpc/virnetclient.h create mode 100644 src/rpc/virnetclientprogram.c create mode 100644 src/rpc/virnetclientprogram.h create mode 100644 src/rpc/virnetclientstream.c create mode 100644 src/rpc/virnetclientstream.h diff --git a/cfg.mk b/cfg.mk index 87be171..14da397 100644 --- a/cfg.mk +++ b/cfg.mk @@ -126,6 +126,9 @@ useless_free_options = \ --name=virJSONValueFree \ --name=virLastErrFreeData \ --name=virNetMessageFree \ + --name=virNetClientFree \ + --name=virNetClientProgramFree \ + --name=virNetClientStreamFree \ --name=virNetServerFree \ --name=virNetServerClientFree \ --name=virNetServerMDNSFree \ diff --git a/po/POTFILES.in b/po/POTFILES.in index 8a0e89f..356f415 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -67,6 +67,8 @@ src/qemu/qemu_monitor_text.c 
src/qemu/qemu_process.c src/remote/remote_client_bodies.h src/remote/remote_driver.c +src/rpc/virnetclient.c +src/rpc/virnetclientprogram.c src/rpc/virnetmessage.c src/rpc/virnetsaslcontext.c src/rpc/virnetsocket.c diff --git a/src/Makefile.am b/src/Makefile.am index 298759e..e6f2a02 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -1154,7 +1154,7 @@ libvirt_qemu_la_LIBADD = libvirt.la $(CYGWIN_EXTRA_LIBADD) EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) -noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la +noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la libvirt-net-rpc-client.la libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ @@ -1204,6 +1204,18 @@ libvirt_net_rpc_server_la_LDFLAGS = \ libvirt_net_rpc_server_la_LIBADD = \ $(CYGWIN_EXTRA_LIBADD) +libvirt_net_rpc_client_la_SOURCES = \ + rpc/virnetclientprogram.h rpc/virnetclientprogram.c \ + rpc/virnetclientstream.h rpc/virnetclientstream.c \ + rpc/virnetclient.h rpc/virnetclient.c +libvirt_net_rpc_client_la_CFLAGS = \ + $(AM_CFLAGS) +libvirt_net_rpc_client_la_LDFLAGS = \ + $(AM_LDFLAGS) \ + $(CYGWIN_EXTRA_LDFLAGS) \ + $(MINGW_EXTRA_LDFLAGS) +libvirt_net_rpc_client_la_LIBADD = \ + $(CYGWIN_EXTRA_LIBADD) libexec_PROGRAMS = diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c new file mode 100644 index 0000000..d0c6884 --- /dev/null +++ b/src/rpc/virnetclient.c @@ -0,0 +1,1165 @@ +/* + * virnetclient.c: generic network RPC client + * + * Copyright (C) 2006-2011 Red Hat, Inc. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * + * Author: Daniel P. Berrange <berrange@redhat.com> + */ + +#include <config.h> + +#include <unistd.h> +#include <poll.h> +#include <signal.h> + +#include "virnetclient.h" +#include "virnetsocket.h" +#include "memory.h" +#include "threads.h" +#include "files.h" +#include "logging.h" +#include "util.h" +#include "virterror_internal.h" + +#define VIR_FROM_THIS VIR_FROM_RPC +#define virNetError(code, ...) \ + virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ + __FUNCTION__, __LINE__, __VA_ARGS__) + +typedef struct _virNetClientCall virNetClientCall; +typedef virNetClientCall *virNetClientCallPtr; + +enum { + VIR_NET_CLIENT_MODE_WAIT_TX, + VIR_NET_CLIENT_MODE_WAIT_RX, + VIR_NET_CLIENT_MODE_COMPLETE, +}; + +struct _virNetClientCall { + int mode; + + virNetMessagePtr msg; + int expectReply; + + virCond cond; + + virNetClientCallPtr next; +}; + + +struct _virNetClient { + int refs; + + virMutex lock; + + virNetSocketPtr sock; + + virNetTLSSessionPtr tls; + char *hostname; + + virNetClientProgramPtr *programs; + size_t nprograms; + + /* For incoming message packets */ + virNetMessage msg; + +#if HAVE_SASL + virNetSASLSessionPtr sasl; +#endif + + /* Self-pipe to wakeup threads waiting in poll() */ + int wakeupSendFD; + int wakeupReadFD; + + /* List of threads currently waiting for dispatch */ + virNetClientCallPtr waitDispatch; + + size_t nstreams; + virNetClientStreamPtr *streams; +}; + + +static void virNetClientLock(virNetClientPtr client) +{ + virMutexLock(&client->lock); +} + + +static void virNetClientUnlock(virNetClientPtr client) +{ + virMutexUnlock(&client->lock); +} + + +static void virNetClientIncomingEvent(virNetSocketPtr sock, + int events, + void *opaque); + +static 
virNetClientPtr virNetClientNew(virNetSocketPtr sock, + const char *hostname) +{ + virNetClientPtr client; + int wakeupFD[2] = { -1, -1 }; + + if (pipe(wakeupFD) < 0) { + virReportSystemError(errno, "%s", + _("unable to make pipe")); + goto error; + } + + if (VIR_ALLOC(client) < 0) + goto no_memory; + + client->refs = 1; + + if (virMutexInit(&client->lock) < 0) + goto error; + + client->sock = sock; + client->wakeupReadFD = wakeupFD[0]; + client->wakeupSendFD = wakeupFD[1]; + wakeupFD[0] = wakeupFD[1] = -1; + + if (hostname && + !(client->hostname = strdup(hostname))) + goto no_memory; + + /* Set up a callback to listen on the socket data */ + if (virNetSocketAddIOCallback(client->sock, + VIR_EVENT_HANDLE_READABLE, + virNetClientIncomingEvent, + client) < 0) + VIR_DEBUG("Failed to add event watch, disabling events"); + + return client; + +no_memory: + virReportOOMError(); +error: + VIR_FORCE_CLOSE(wakeupFD[0]); + VIR_FORCE_CLOSE(wakeupFD[1]); + virNetClientFree(client); + return NULL; +} + + +virNetClientPtr virNetClientNewUNIX(const char *path, + bool spawnDaemon, + const char *binary) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectUNIX(path, spawnDaemon, binary, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + + +virNetClientPtr virNetClientNewTCP(const char *nodename, + const char *service) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectTCP(nodename, service, &sock) < 0) + return NULL; + + return virNetClientNew(sock, nodename); +} + +virNetClientPtr virNetClientNewSSH(const char *nodename, + const char *service, + const char *binary, + const char *username, + bool noTTY, + const char *netcat, + const char *path) +{ + virNetSocketPtr sock; + + if (virNetSocketNewConnectSSH(nodename, service, binary, username, noTTY, netcat, path, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + +virNetClientPtr virNetClientNewExternal(const char **cmdargv) +{ + virNetSocketPtr sock; + + if 
(virNetSocketNewConnectExternal(cmdargv, &sock) < 0) + return NULL; + + return virNetClientNew(sock, NULL); +} + + +void virNetClientRef(virNetClientPtr client) +{ + virNetClientLock(client); + client->refs++; + virNetClientUnlock(client); +} + + +void virNetClientFree(virNetClientPtr client) +{ + int i; + + if (!client) + return; + + virNetClientLock(client); + client->refs--; + if (client->refs > 0) { + virNetClientUnlock(client); + return; + } + + for (i = 0 ; i < client->nprograms ; i++) + virNetClientProgramFree(client->programs[i]); + VIR_FREE(client->programs); + + VIR_FORCE_CLOSE(client->wakeupSendFD); + VIR_FORCE_CLOSE(client->wakeupReadFD); + + VIR_FREE(client->hostname); + + virNetSocketRemoveIOCallback(client->sock); + virNetSocketFree(client->sock); + virNetTLSSessionFree(client->tls); +#if HAVE_SASL + virNetSASLSessionFree(client->sasl); +#endif + virNetClientUnlock(client); + virMutexDestroy(&client->lock); + + VIR_FREE(client); +} + + +#if HAVE_SASL +void virNetClientSetSASLSession(virNetClientPtr client, + virNetSASLSessionPtr sasl) +{ + virNetClientLock(client); + client->sasl = sasl; + virNetSASLSessionRef(sasl); + virNetSocketSetSASLSession(client->sock, client->sasl); + virNetClientUnlock(client); +} +#endif + + +int virNetClientSetTLSSession(virNetClientPtr client, + virNetTLSContextPtr tls) +{ + int ret; + char buf[1]; + int len; + struct pollfd fds[1]; +#ifdef HAVE_PTHREAD_SIGMASK + sigset_t oldmask, blockedsigs; + + sigemptyset (&blockedsigs); + sigaddset (&blockedsigs, SIGWINCH); + sigaddset (&blockedsigs, SIGCHLD); + sigaddset (&blockedsigs, SIGPIPE); +#endif + + virNetClientLock(client); + + if (!(client->tls = virNetTLSSessionNew(tls, + client->hostname))) + goto error; + + virNetSocketSetTLSSession(client->sock, client->tls); + + for (;;) { + ret = virNetTLSSessionHandshake(client->tls); + + if (ret < 0) + goto error; + if (ret == 0) + break; + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[0].revents = 0; + if 
(virNetTLSSessionGetHandshakeStatus(client->tls) == + VIR_NET_TLS_HANDSHAKE_RECVING) + fds[0].events = POLLIN; + else + fds[0].events = POLLOUT; + + /* Block SIGWINCH from interrupting poll in curses programs, + * then restore the original signal mask again immediately + * after the call (RHBZ#567931). Same for SIGCHLD and SIGPIPE + * at the suggestion of Paolo Bonzini and Daniel Berrange. + */ +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll: + ret = poll(fds, ARRAY_CARDINALITY(fds), -1); + if (ret < 0 && errno == EAGAIN) + goto repoll; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &oldmask, NULL)); +#endif + } + + ret = virNetTLSContextCheckCertificate(tls, client->tls); + + if (ret < 0) + goto error; + + /* At this point, the server is verifying _our_ certificate, IP address, + * etc. If we make the grade, it will send us a '\1' byte. + */ + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[0].revents = 0; + fds[0].events = POLLIN; + +#ifdef HAVE_PTHREAD_SIGMASK + /* Block SIGWINCH from interrupting poll in curses programs */ + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll2: + ret = poll(fds, ARRAY_CARDINALITY(fds), -1); + if (ret < 0 && errno == EAGAIN) + goto repoll2; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_BLOCK, &oldmask, NULL)); +#endif + + len = virNetTLSSessionRead(client->tls, buf, 1); + if (len < 0) { + virReportSystemError(errno, "%s", + _("Unable to read TLS confirmation")); + goto error; + } + if (len != 1 || buf[0] != '\1') { + virNetError(VIR_ERR_RPC, "%s", + _("server verification (of our certificate or IP " + "address) failed")); + goto error; + } + + virNetClientUnlock(client); + return 0; + +error: + virNetTLSSessionFree(client->tls); + client->tls = NULL; + virNetClientUnlock(client); + return -1; +} + +bool virNetClientIsEncrypted(virNetClientPtr client) +{ + bool ret = false; + 
virNetClientLock(client); + if (client->tls) + ret = true; +#if HAVE_SASL + if (client->sasl) + ret = true; +#endif + virNetClientUnlock(client); + return ret; +} + + +int virNetClientAddProgram(virNetClientPtr client, + virNetClientProgramPtr prog) +{ + virNetClientLock(client); + + if (VIR_EXPAND_N(client->programs, client->nprograms, 1) < 0) + goto no_memory; + + client->programs[client->nprograms-1] = prog; + virNetClientProgramRef(prog); + + virNetClientUnlock(client); + return 0; + +no_memory: + virReportOOMError(); + virNetClientUnlock(client); + return -1; +} + + +int virNetClientAddStream(virNetClientPtr client, + virNetClientStreamPtr st) +{ + virNetClientLock(client); + + if (VIR_EXPAND_N(client->streams, client->nstreams, 1) < 0) + goto no_memory; + + client->streams[client->nstreams-1] = st; + virNetClientStreamRef(st); + + virNetClientUnlock(client); + return 0; + +no_memory: + virReportOOMError(); + virNetClientUnlock(client); + return -1; +} + + +void virNetClientRemoveStream(virNetClientPtr client, + virNetClientStreamPtr st) +{ + virNetClientLock(client); + size_t i; + for (i = 0 ; i < client->nstreams ; i++) { + if (client->streams[i] == st) + break; + } + if (i == client->nstreams) + goto cleanup; + + if (client->nstreams > 1) { + memmove(client->streams + i, + client->streams + i + 1, + sizeof(*client->streams) * + (client->nstreams - (i + 1))); + VIR_SHRINK_N(client->streams, client->nstreams, 1); + } else { + VIR_FREE(client->streams); + client->nstreams = 0; + } + virNetClientStreamFree(st); + +cleanup: + virNetClientUnlock(client); +} + + +const char *virNetClientLocalAddrString(virNetClientPtr client) +{ + return virNetSocketLocalAddrString(client->sock); +} + +const char *virNetClientRemoteAddrString(virNetClientPtr client) +{ + return virNetSocketRemoteAddrString(client->sock); +} + +int virNetClientGetTLSKeySize(virNetClientPtr client) +{ + int ret = 0; + virNetClientLock(client); + if (client->tls) + ret = 
virNetTLSSessionGetKeySize(client->tls); + virNetClientUnlock(client); + return ret; +} + +static int +virNetClientCallDispatchReply(virNetClientPtr client) +{ + virNetClientCallPtr thecall; + + /* Ok, definitely got an RPC reply now find + out who's been waiting for it */ + thecall = client->waitDispatch; + while (thecall && + !(thecall->msg->header.prog == client->msg.header.prog && + thecall->msg->header.vers == client->msg.header.vers && + thecall->msg->header.serial == client->msg.header.serial)) + thecall = thecall->next; + + if (!thecall) { + virNetError(VIR_ERR_RPC, + _("no call waiting for reply with prog %d vers %d serial %d"), + client->msg.header.prog, client->msg.header.vers, client->msg.header.serial); + return -1; + } + + memcpy(thecall->msg->buffer, client->msg.buffer, sizeof(client->msg.buffer)); + memcpy(&thecall->msg->header, &client->msg.header, sizeof(client->msg.header)); + thecall->msg->bufferLength = client->msg.bufferLength; + thecall->msg->bufferOffset = client->msg.bufferOffset; + + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + + return 0; +} + +static int virNetClientCallDispatchMessage(virNetClientPtr client) +{ + size_t i; + virNetClientProgramPtr prog = NULL; + + for (i = 0 ; i < client->nprograms ; i++) { + if (virNetClientProgramMatches(client->programs[i], + &client->msg)) { + prog = client->programs[i]; + break; + } + } + if (!prog) { + VIR_DEBUG("No program found for event with prog=%d vers=%d", + client->msg.header.prog, client->msg.header.vers); + return -1; + } + + virNetClientProgramDispatch(prog, client, &client->msg); + + return 0; +} + +static int virNetClientCallDispatchStream(virNetClientPtr client) +{ + size_t i; + virNetClientStreamPtr st = NULL; + virNetClientCallPtr thecall; + + /* First identify what stream this packet is directed at */ + for (i = 0 ; i < client->nstreams ; i++) { + if (virNetClientStreamMatches(client->streams[i], + &client->msg)) { + st = client->streams[i]; + break; + } + } + if (!st) { + 
VIR_DEBUG("No stream found for packet with prog=%d vers=%d serial=%u proc=%u", + client->msg.header.prog, client->msg.header.vers, + client->msg.header.serial, client->msg.header.proc); + return -1; + } + + /* Finish/Abort are synchronous, so also see if there's an + * (optional) call waiting for this stream packet */ + thecall = client->waitDispatch; + while (thecall && + !(thecall->msg->header.prog == client->msg.header.prog && + thecall->msg->header.vers == client->msg.header.vers && + thecall->msg->header.serial == client->msg.header.serial)) + thecall = thecall->next; + + VIR_DEBUG("Found call %p", thecall); + + /* Status is either + * - REMOTE_OK - no payload for streams + * - REMOTE_ERROR - followed by a remote_error struct + * - REMOTE_CONTINUE - followed by a raw data packet + */ + switch (client->msg.header.status) { + case VIR_NET_CONTINUE: { + if (virNetClientStreamQueuePacket(st, &client->msg) < 0) + return -1; + + if (thecall && thecall->expectReply) { + VIR_DEBUG("Got sync data packet completion"); + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } else { + // XXX + //remoteStreamEventTimerUpdate(privst); + } + return 0; + } + + case VIR_NET_OK: + if (thecall && thecall->expectReply) { + VIR_DEBUG("Got a synchronous confirm"); + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } else { + VIR_DEBUG("Got unexpected async stream finish confirmation"); + return -1; + } + return 0; + + case VIR_NET_ERROR: + /* No call, so queue the error against the stream */ + if (virNetClientStreamSetError(st, &client->msg) < 0) + return -1; + + if (thecall && thecall->expectReply) { + VIR_DEBUG("Got a synchronous error"); + /* Raise error now, so that this call will see it immediately */ + virNetClientStreamRaiseError(st); + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } + return 0; + + default: + VIR_WARN("Stream with unexpected serial=%d, proc=%d, status=%d", + client->msg.header.serial, client->msg.header.proc, + client->msg.header.status); + return -1; + } + + 
return 0; +} + + +static int +virNetClientCallDispatch(virNetClientPtr client) +{ + if (virNetMessageDecodeHeader(&client->msg) < 0) + return -1; + + VIR_DEBUG("Incoming message prog %d vers %d proc %d type %d status %d serial %d", + client->msg.header.prog, client->msg.header.vers, + client->msg.header.proc, client->msg.header.type, + client->msg.header.status, client->msg.header.serial); + + switch (client->msg.header.type) { + case VIR_NET_REPLY: /* Normal RPC replies */ + return virNetClientCallDispatchReply(client); + + case VIR_NET_MESSAGE: /* Async notifications */ + return virNetClientCallDispatchMessage(client); + + case VIR_NET_STREAM: /* Stream protocol */ + return virNetClientCallDispatchStream(client); + + default: + virNetError(VIR_ERR_RPC, + _("got unexpected RPC call prog %d vers %d proc %d type %d"), + client->msg.header.prog, client->msg.header.vers, + client->msg.header.proc, client->msg.header.type); + return -1; + } +} + + +static ssize_t +virNetClientIOWriteMessage(virNetClientPtr client, + virNetClientCallPtr thecall) +{ + ssize_t ret; + + ret = virNetSocketWrite(client->sock, + thecall->msg->buffer + thecall->msg->bufferOffset, + thecall->msg->bufferLength - thecall->msg->bufferOffset); + if (ret <= 0) + return ret; + + thecall->msg->bufferOffset += ret; + + if (thecall->msg->bufferOffset == thecall->msg->bufferLength) { + thecall->msg->bufferOffset = thecall->msg->bufferLength = 0; + if (thecall->expectReply) + thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX; + else + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } + + return ret; +} + + +static ssize_t +virNetClientIOHandleOutput(virNetClientPtr client) +{ + virNetClientCallPtr thecall = client->waitDispatch; + + while (thecall && + thecall->mode != VIR_NET_CLIENT_MODE_WAIT_TX) + thecall = thecall->next; + + if (!thecall) + return -1; /* Shouldn't happen, but you never know... 
*/ + + while (thecall) { + ssize_t ret = virNetClientIOWriteMessage(client, thecall); + if (ret < 0) + return ret; + + if (thecall->mode == VIR_NET_CLIENT_MODE_WAIT_TX) + return 0; /* Blocking write, to back to event loop */ + + thecall = thecall->next; + } + + return 0; /* No more calls to send, all done */ +} + +static ssize_t +virNetClientIOReadMessage(virNetClientPtr client) +{ + size_t wantData; + ssize_t ret; + + /* Start by reading length word */ + if (client->msg.bufferLength == 0) + client->msg.bufferLength = 4; + + wantData = client->msg.bufferLength - client->msg.bufferOffset; + + ret = virNetSocketRead(client->sock, + client->msg.buffer + client->msg.bufferOffset, + wantData); + if (ret <= 0) + return ret; + + client->msg.bufferOffset += ret; + + return ret; +} + + +static ssize_t +virNetClientIOHandleInput(virNetClientPtr client) +{ + /* Read as much data as is available, until we get + * EAGAIN + */ + for (;;) { + ssize_t ret = virNetClientIOReadMessage(client); + + if (ret < 0) + return -1; + if (ret == 0) + return 0; /* Blocking on read */ + + /* Check for completion of our goal */ + if (client->msg.bufferOffset == client->msg.bufferLength) { + if (client->msg.bufferOffset == 4) { + ret = virNetMessageDecodeLength(&client->msg); + if (ret < 0) + return -1; + + /* + * We'll carry on around the loop to immediately + * process the message body, because it has probably + * already arrived. Worst case, we'll get EAGAIN on + * next iteration. + */ + } else { + ret = virNetClientCallDispatch(client); + client->msg.bufferOffset = client->msg.bufferLength = 0; + /* + * We've completed one call, but we don't want to + * spin around the loop forever if there are many + * incoming async events, or replies for other + * thread's RPC calls. We want to get out & let + * any other thread take over as soon as we've + * got our reply. When SASL is active though, we + * may have read more data off the wire than we + * initially wanted & cached it in memory. 
In this + * case, poll() would not detect that there is more + * ready todo. + * + * So if SASL is active *and* some SASL data is + * already cached, then we'll process that now, + * before returning. + */ + if (ret == 0 && + virNetSocketHasCachedData(client->sock)) + continue; + return ret; + } + } + } +} + + +/* + * Process all calls pending dispatch/receive until we + * get a reply to our own call. Then quit and pass the buck + * to someone else. + */ +static int virNetClientIOEventLoop(virNetClientPtr client, + virNetClientCallPtr thiscall) +{ + struct pollfd fds[2]; + int ret; + + fds[0].fd = virNetSocketGetFD(client->sock); + fds[1].fd = client->wakeupReadFD; + + for (;;) { + virNetClientCallPtr tmp = client->waitDispatch; + virNetClientCallPtr prev; + char ignore; +#ifdef HAVE_PTHREAD_SIGMASK + sigset_t oldmask, blockedsigs; +#endif + int timeout = -1; + + /* If we have existing SASL decoded data we + * don't want to sleep in the poll(), just + * check if any other FDs are also ready + */ + if (virNetSocketHasCachedData(client->sock)) + timeout = 0; + + fds[0].events = fds[0].revents = 0; + fds[1].events = fds[1].revents = 0; + + fds[1].events = POLLIN; + while (tmp) { + if (tmp->mode == VIR_NET_CLIENT_MODE_WAIT_RX) + fds[0].events |= POLLIN; + if (tmp->mode == VIR_NET_CLIENT_MODE_WAIT_TX) + fds[0].events |= POLLOUT; + + tmp = tmp->next; + } + + /* We have to be prepared to receive stream data + * regardless of whether any of the calls waiting + * for dispatch are for streams. + */ + if (client->nstreams) + fds[0].events |= POLLIN; + + /* Release lock while poll'ing so other threads + * can stuff themselves on the queue */ + virNetClientUnlock(client); + + /* Block SIGWINCH from interrupting poll in curses programs, + * then restore the original signal mask again immediately + * after the call (RHBZ#567931). Same for SIGCHLD and SIGPIPE + * at the suggestion of Paolo Bonzini and Daniel Berrange. 
+ */ +#ifdef HAVE_PTHREAD_SIGMASK + sigemptyset (&blockedsigs); + sigaddset (&blockedsigs, SIGWINCH); + sigaddset (&blockedsigs, SIGCHLD); + sigaddset (&blockedsigs, SIGPIPE); + ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask)); +#endif + + repoll: + ret = poll(fds, ARRAY_CARDINALITY(fds), timeout); + if (ret < 0 && errno == EAGAIN) + goto repoll; + +#ifdef HAVE_PTHREAD_SIGMASK + ignore_value(pthread_sigmask(SIG_SETMASK, &oldmask, NULL)); +#endif + + virNetClientLock(client); + + /* If we have existing SASL decoded data, pretend + * the socket became readable so we consume it + */ + if (virNetSocketHasCachedData(client->sock)) + fds[0].revents |= POLLIN; + + if (fds[1].revents) { + VIR_DEBUG("Woken up from poll by other thread"); + if (saferead(client->wakeupReadFD, &ignore, sizeof(ignore)) != sizeof(ignore)) { + virReportSystemError(errno, "%s", + _("read on wakeup fd failed")); + goto error; + } + } + + if (ret < 0) { + if (errno == EWOULDBLOCK) + continue; + virReportSystemError(errno, + "%s", _("poll on socket failed")); + goto error; + } + + if (fds[0].revents & POLLOUT) { + if (virNetClientIOHandleOutput(client) < 0) + goto error; + } + + if (fds[0].revents & POLLIN) { + if (virNetClientIOHandleInput(client) < 0) + goto error; + } + + /* Iterate through waiting threads and if + * any are complete then tell 'em to wakeup + */ + tmp = client->waitDispatch; + prev = NULL; + while (tmp) { + if (tmp != thiscall && + tmp->mode == VIR_NET_CLIENT_MODE_COMPLETE) { + /* Take them out of the list */ + if (prev) + prev->next = tmp->next; + else + client->waitDispatch = tmp->next; + + /* And wake them up.... + * ...they won't actually wakeup until + * we release our mutex a short while + * later... 
+ */ + VIR_DEBUG("Waking up sleep %p %p", tmp, client->waitDispatch); + virCondSignal(&tmp->cond); + } else { + prev = tmp; + } + tmp = tmp->next; + } + + /* Now see if *we* are done */ + if (thiscall->mode == VIR_NET_CLIENT_MODE_COMPLETE) { + /* We're at head of the list already, so + * remove us + */ + client->waitDispatch = thiscall->next; + VIR_DEBUG("Giving up the buck %p %p", thiscall, client->waitDispatch); + /* See if someone else is still waiting + * and if so, then pass the buck ! */ + if (client->waitDispatch) { + VIR_DEBUG("Passing the buck to %p", client->waitDispatch); + virCondSignal(&client->waitDispatch->cond); + } + return 0; + } + + + if (fds[0].revents & (POLLHUP | POLLERR)) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("received hangup / error event on socket")); + goto error; + } + } + + +error: + client->waitDispatch = thiscall->next; + VIR_DEBUG("Giving up the buck due to I/O error %p %p", thiscall, client->waitDispatch); + /* See if someone else is still waiting + * and if so, then pass the buck ! */ + if (client->waitDispatch) { + VIR_DEBUG("Passing the buck to %p", client->waitDispatch); + virCondSignal(&client->waitDispatch->cond); + } + return -1; +} + + +/* + * This function sends a message to remote server and awaits a reply + * + * NB. This does not free the args structure (not desirable, since you + * often want this allocated on the stack or else it contains strings + * which come from the user). It does however free any intermediate + * results, eg. the error structure if there is one. + * + * NB(2). Make sure to memset (&ret, 0, sizeof ret) before calling, + * else Bad Things will happen in the XDR code. + * + * NB(3) You must have the client lock before calling this + * + * NB(4) This is very complicated. Multiple threads are allowed to + * use the client for RPC at the same time. Obviously only one of + * them can. So if someone's using the socket, other threads are put + * to sleep on condition variables. 
The existing thread may completely + * send & receive their RPC call/reply while they're asleep. Or it + * may only get around to dealing with sending the call. Or it may + * get around to neither. So upon waking up from slumber, the other + * thread may or may not have more work todo. + * + * We call this dance 'passing the buck' + * + * http://en.wikipedia.org/wiki/Passing_the_buck + * + * "Buck passing or passing the buck is the action of transferring + * responsibility or blame unto another person. It is also used as + * a strategy in power politics when the actions of one country/ + * nation are blamed on another, providing an opportunity for war." + * + * NB(5) Don't Panic! + */ +static int virNetClientIO(virNetClientPtr client, + virNetClientCallPtr thiscall) +{ + int rv = -1; + + VIR_DEBUG("program=%u version=%u serial=%u proc=%d type=%d length=%zu dispatch=%p", + thiscall->msg->header.prog, + thiscall->msg->header.vers, + thiscall->msg->header.serial, + thiscall->msg->header.proc, + thiscall->msg->header.type, + thiscall->msg->bufferLength, + client->waitDispatch); + + /* Check to see if another thread is dispatching */ + if (client->waitDispatch) { + /* Stick ourselves on the end of the wait queue */ + virNetClientCallPtr tmp = client->waitDispatch; + char ignore = 1; + while (tmp && tmp->next) + tmp = tmp->next; + if (tmp) + tmp->next = thiscall; + else + client->waitDispatch = thiscall; + + /* Force other thread to wakeup from poll */ + if (safewrite(client->wakeupSendFD, &ignore, sizeof(ignore)) != sizeof(ignore)) { + if (tmp) + tmp->next = NULL; + else + client->waitDispatch = NULL; + virReportSystemError(errno, "%s", + _("failed to wake up polling thread")); + return -1; + } + + VIR_DEBUG("Going to sleep %p %p", client->waitDispatch, thiscall); + /* Go to sleep while other thread is working... 
*/ + if (virCondWait(&thiscall->cond, &client->lock) < 0) { + if (client->waitDispatch == thiscall) { + client->waitDispatch = thiscall->next; + } else { + tmp = client->waitDispatch; + while (tmp && tmp->next && + tmp->next != thiscall) { + tmp = tmp->next; + } + if (tmp && tmp->next == thiscall) + tmp->next = thiscall->next; + } + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("failed to wait on condition")); + return -1; + } + + VIR_DEBUG("Wokeup from sleep %p %p", client->waitDispatch, thiscall); + /* Two reasons we can be woken up + * 1. Other thread has got our reply ready for us + * 2. Other thread is all done, and it is our turn to + * be the dispatcher to finish waiting for + * our reply + */ + if (thiscall->mode == VIR_NET_CLIENT_MODE_COMPLETE) { + rv = 0; + /* + * We avoided catching the buck and our reply is ready ! + * We've already had 'thiscall' removed from the list + * so just need to (maybe) handle errors & free it + */ + goto cleanup; + } + + /* Grr, someone passed the buck onto us ... */ + + } else { + /* We're first to catch the buck */ + client->waitDispatch = thiscall; + } + + VIR_DEBUG("We have the buck %p %p", client->waitDispatch, thiscall); + /* + * The buck stops here! + * + * At this point we're about to own the dispatch + * process... + */ + + /* + * Avoid needless wake-ups of the event loop in the + * case where this call is being made from a different + * thread than the event loop. 
These wake-ups would + * cause the event loop thread to be blocked on the + * mutex for the duration of the call + */ + virNetSocketUpdateIOCallback(client->sock, 0); + + rv = virNetClientIOEventLoop(client, thiscall); + + virNetSocketUpdateIOCallback(client->sock, VIR_EVENT_HANDLE_READABLE); + +cleanup: + VIR_DEBUG("All done with our call %p %p %d", client->waitDispatch, thiscall, rv); + return rv; +} + + +void virNetClientIncomingEvent(virNetSocketPtr sock, + int events, + void *opaque) +{ + virNetClientPtr client = opaque; + + virNetClientLock(client); + + /* This should be impossible, but it doesn't hurt to check */ + if (client->waitDispatch) + goto done; + + VIR_DEBUG("Event fired %p %d", sock, events); + + if (events & (VIR_EVENT_HANDLE_HANGUP | VIR_EVENT_HANDLE_ERROR)) { + VIR_DEBUG("%s : VIR_EVENT_HANDLE_HANGUP or " + "VIR_EVENT_HANDLE_ERROR encountered", __FUNCTION__); + virNetSocketRemoveIOCallback(sock); + goto done; + } + + if (virNetClientIOHandleInput(client) < 0) + VIR_DEBUG("Something went wrong during async message processing"); + +done: + virNetClientUnlock(client); +} + + +int virNetClientSend(virNetClientPtr client, + virNetMessagePtr msg, + bool expectReply) +{ + virNetClientCallPtr call; + int ret = -1; + + if (VIR_ALLOC(call) < 0) { + virReportOOMError(); + return -1; + } + + virNetClientLock(client); + + if (virCondInit(&call->cond) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("cannot initialize condition variable")); + goto cleanup; + } + + if (msg->bufferLength) + call->mode = VIR_NET_CLIENT_MODE_WAIT_TX; + else + call->mode = VIR_NET_CLIENT_MODE_WAIT_RX; + call->msg = msg; + call->expectReply = expectReply; + + ret = virNetClientIO(client, call); + +cleanup: + ignore_value(virCondDestroy(&call->cond)); + VIR_FREE(call); + virNetClientUnlock(client); + return ret; +} diff --git a/src/rpc/virnetclient.h b/src/rpc/virnetclient.h new file mode 100644 index 0000000..bbe5a69 --- /dev/null +++ b/src/rpc/virnetclient.h @@ -0,0 +1,86 @@ 
+/*
+ * virnetclient.h: generic network RPC client
+ *
+ * Copyright (C) 2006-2011 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_CLIENT_H__
+# define __VIR_NET_CLIENT_H__
+
+# include <stdbool.h>
+
+# include "virnettlscontext.h"
+# include "virnetmessage.h"
+# ifdef HAVE_SASL
+#  include "virnetsaslcontext.h"
+# endif
+# include "virnetclientprogram.h"
+# include "virnetclientstream.h"
+
+
+virNetClientPtr virNetClientNewUNIX(const char *path,
+                                    bool spawnDaemon,
+                                    const char *daemon);
+
+virNetClientPtr virNetClientNewTCP(const char *nodename,
+                                   const char *service);
+
+virNetClientPtr virNetClientNewSSH(const char *nodename,
+                                   const char *service,
+                                   const char *binary,
+                                   const char *username,
+                                   bool noTTY,
+                                   const char *netcat,
+                                   const char *path);
+
+virNetClientPtr virNetClientNewExternal(const char **cmdargv);
+
+void virNetClientRef(virNetClientPtr client);
+
+int virNetClientAddProgram(virNetClientPtr client,
+                           virNetClientProgramPtr prog);
+
+int virNetClientAddStream(virNetClientPtr client,
+                          virNetClientStreamPtr st);
+
+void virNetClientRemoveStream(virNetClientPtr client,
+                              virNetClientStreamPtr st);
+
+int virNetClientSend(virNetClientPtr client,
+                     virNetMessagePtr msg,
+                     bool expectReply);
+
+# ifdef HAVE_SASL
+void virNetClientSetSASLSession(virNetClientPtr client,
+                                virNetSASLSessionPtr sasl);
+# endif
+
+int virNetClientSetTLSSession(virNetClientPtr client,
+                              virNetTLSContextPtr tls);
+
+bool virNetClientIsEncrypted(virNetClientPtr client);
+
+const char *virNetClientLocalAddrString(virNetClientPtr client);
+const char *virNetClientRemoteAddrString(virNetClientPtr client);
+
+int virNetClientGetTLSKeySize(virNetClientPtr client);
+
+void virNetClientFree(virNetClientPtr client);
+
+#endif /* __VIR_NET_CLIENT_H__ */
diff --git a/src/rpc/virnetclientprogram.c b/src/rpc/virnetclientprogram.c
new file mode 100644
index 0000000..8414ad8
--- /dev/null
+++ b/src/rpc/virnetclientprogram.c
@@ -0,0 +1,378 @@
+/*
+ * virnetclientprogram.c: generic network RPC client program
+ *
+ * Copyright (C) 2006-2011 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#include "virnetclientprogram.h"
+#include "virnetclient.h"
+#include "virnetprotocol.h"
+
+#include "memory.h"
+#include "virterror_internal.h"
+#include "logging.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+#define virNetError(code, ...)                                    \
+    virReportErrorHelper(VIR_FROM_THIS, code, __FILE__,           \
+                         __FUNCTION__, __LINE__, __VA_ARGS__)
+
+struct _virNetClientProgram {
+    int refs;
+
+    unsigned program;
+    unsigned version;
+    virNetClientProgramEventPtr events;
+    size_t nevents;
+    void *eventOpaque;
+};
+
+/*
+ * Allocate a descriptor for RPC program number 'program' at protocol
+ * 'version'.  'events' is an array of 'nevents' asynchronous event
+ * handlers; only the pointer is stored, the array is not copied.
+ * 'eventOpaque' is passed back to every event handler.
+ *
+ * Returns a new object holding one reference, or NULL on OOM.
+ */
+virNetClientProgramPtr virNetClientProgramNew(unsigned program,
+                                              unsigned version,
+                                              virNetClientProgramEventPtr events,
+                                              size_t nevents,
+                                              void *eventOpaque)
+{
+    virNetClientProgramPtr prog;
+
+    if (VIR_ALLOC(prog) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    prog->refs = 1;
+    prog->program = program;
+    prog->version = version;
+    prog->events = events;
+    prog->nevents = nevents;
+    prog->eventOpaque = eventOpaque;
+
+    return prog;
+}
+
+
+/* Take an additional reference on 'prog'. */
+void virNetClientProgramRef(virNetClientProgramPtr prog)
+{
+    prog->refs++;
+}
+
+
+/* Drop one reference on 'prog' (NULL is a no-op); the object is
+ * freed when the last reference is gone. */
+void virNetClientProgramFree(virNetClientProgramPtr prog)
+{
+    if (!prog)
+        return;
+
+    prog->refs--;
+    if (prog->refs > 0)
+        return;
+
+    VIR_FREE(prog);
+}
+
+
+/* Returns the RPC program number 'prog' was created with. */
+unsigned virNetClientProgramGetProgram(virNetClientProgramPtr prog)
+{
+    return prog->program;
+}
+
+
+/* Returns the protocol version 'prog' was created with. */
+unsigned virNetClientProgramGetVersion(virNetClientProgramPtr prog)
+{
+    return prog->version;
+}
+
+
+/* Returns 1 if the program number and version in the header of
+ * 'msg' are those of 'prog', 0 otherwise. */
+int virNetClientProgramMatches(virNetClientProgramPtr prog,
+                               virNetMessagePtr msg)
+{
+    if (prog->program == msg->header.prog &&
+        prog->version == msg->header.vers)
+        return 1;
+    return 0;
+}
+
+
+/*
+ * Decode the virNetMessageError payload of 'msg' and raise it as a
+ * libvirt error via virRaiseErrorFull().
+ *
+ * Returns 0 once the error has been raised, -1 if the payload could
+ * not be decoded.
+ */
+static int
+virNetClientProgramDispatchError(virNetClientProgramPtr prog ATTRIBUTE_UNUSED,
+                                 virNetMessagePtr msg)
+{
+    virNetMessageError err;
+    int ret = -1;
+
+    memset(&err, 0, sizeof(err));
+
+    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
+        goto cleanup;
+
+    /* Interop for virErrorNumber glitch in 0.8.0, if server is
+     * 0.7.1 through 0.7.7; see comments in virterror.h. */
+    switch (err.code) {
+    case VIR_WAR_NO_NWFILTER:
+        /* no way to tell old VIR_WAR_NO_SECRET apart from
+         * VIR_WAR_NO_NWFILTER, but both are very similar
+         * warnings, so ignore the difference */
+        break;
+    case VIR_ERR_INVALID_NWFILTER:
+    case VIR_ERR_NO_NWFILTER:
+    case VIR_ERR_BUILD_FIREWALL:
+        /* server was trying to pass VIR_ERR_INVALID_SECRET,
+         * VIR_ERR_NO_SECRET, or VIR_ERR_CONFIG_UNSUPPORTED */
+        if (err.domain != VIR_FROM_NWFILTER)
+            err.code += 4;
+        break;
+    case VIR_WAR_NO_SECRET:
+        if (err.domain == VIR_FROM_QEMU)
+            err.code = VIR_ERR_OPERATION_TIMEOUT;
+        break;
+    case VIR_ERR_INVALID_SECRET:
+        if (err.domain == VIR_FROM_XEN)
+            err.code = VIR_ERR_MIGRATE_PERSIST_FAILED;
+        break;
+    default:
+        /* Nothing to alter. */
+        break;
+    }
+
+    if (err.domain == VIR_FROM_REMOTE &&
+        err.code == VIR_ERR_RPC &&
+        err.level == VIR_ERR_ERROR &&
+        err.message &&
+        STRPREFIX(*err.message, "unknown procedure")) {
+        virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
+                          err.domain,
+                          VIR_ERR_NO_SUPPORT,
+                          err.level,
+                          err.str1 ? *err.str1 : NULL,
+                          err.str2 ? *err.str2 : NULL,
+                          err.str3 ? *err.str3 : NULL,
+                          err.int1,
+                          err.int2,
+                          "%s", *err.message);
+    } else {
+        virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
+                          err.domain,
+                          err.code,
+                          err.level,
+                          err.str1 ? *err.str1 : NULL,
+                          err.str2 ? *err.str2 : NULL,
+                          err.str3 ? *err.str3 : NULL,
+                          err.int1,
+                          err.int2,
+                          "%s", err.message ? *err.message : _("Unknown error"));
+    }
+
+    ret = 0;
+
+cleanup:
+    xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
+    return ret;
+}
+
+
+/* Find the event handler registered for 'procedure'; NULL if none. */
+static virNetClientProgramEventPtr virNetClientProgramGetEvent(virNetClientProgramPtr prog,
+                                                               int procedure)
+{
+    size_t i;
+
+    for (i = 0 ; i < prog->nevents ; i++) {
+        if (prog->events[i].proc == procedure)
+            return &prog->events[i];
+    }
+
+    return NULL;
+}
+
+
+/*
+ * Deliver an asynchronous event message to its registered handler.
+ *
+ * Returns -1 if the header does not describe an event of this
+ * program, no handler exists for the procedure, or on OOM.  Returns
+ * 0 otherwise; a payload that fails to decode is dropped and still
+ * yields 0.
+ */
+int virNetClientProgramDispatch(virNetClientProgramPtr prog,
+                                virNetClientPtr client,
+                                virNetMessagePtr msg)
+{
+    virNetClientProgramEventPtr event;
+    char *evdata;
+
+    VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%d proc=%d",
+              msg->header.prog, msg->header.vers, msg->header.type,
+              msg->header.status, msg->header.serial, msg->header.proc);
+
+    /* Check version, etc. */
+    if (msg->header.prog != prog->program) {
+        VIR_ERROR(_("program mismatch in event (actual %x, expected %x)"),
+                  msg->header.prog, prog->program);
+        return -1;
+    }
+
+    if (msg->header.vers != prog->version) {
+        VIR_ERROR(_("version mismatch in event (actual %x, expected %x)"),
+                  msg->header.vers, prog->version);
+        return -1;
+    }
+
+    if (msg->header.status != VIR_NET_OK) {
+        VIR_ERROR(_("status mismatch in event (actual %x, expected %x)"),
+                  msg->header.status, VIR_NET_OK);
+        return -1;
+    }
+
+    if (msg->header.type != VIR_NET_MESSAGE) {
+        VIR_ERROR(_("type mismatch in event (actual %x, expected %x)"),
+                  msg->header.type, VIR_NET_MESSAGE);
+        return -1;
+    }
+
+    event = virNetClientProgramGetEvent(prog, msg->header.proc);
+
+    if (!event) {
+        VIR_ERROR(_("No event expected with procedure %x"),
+                  msg->header.proc);
+        return -1;
+    }
+
+    if (VIR_ALLOC_N(evdata, event->msg_len) < 0) {
+        virReportOOMError();
+        return -1;
+    }
+
+    if (virNetMessageDecodePayload(msg, event->msg_filter, evdata) < 0)
+        goto cleanup;
+
+    event->func(prog, client, evdata, prog->eventOpaque);
+
+    xdr_free(event->msg_filter, evdata);
+
+cleanup:
+    VIR_FREE(evdata);
+    return 0;
+}
+
+
+/*
+ * Perform a synchronous RPC call of procedure 'proc' over 'client'.
+ * 'args' is serialized with 'args_filter'; on success the reply
+ * payload is deserialized into 'ret' with 'ret_filter'.
+ *
+ * Returns 0 on success, -1 on failure.  An error reply from the
+ * server is raised via virNetClientProgramDispatchError().
+ */
+int virNetClientProgramCall(virNetClientProgramPtr prog,
+                            virNetClientPtr client,
+                            unsigned serial,
+                            int proc,
+                            xdrproc_t args_filter, void *args,
+                            xdrproc_t ret_filter, void *ret)
+{
+    virNetMessagePtr msg;
+
+    if (!(msg = virNetMessageNew()))
+        return -1;
+
+    msg->header.prog = prog->program;
+    msg->header.vers = prog->version;
+    msg->header.status = VIR_NET_OK;
+    msg->header.type = VIR_NET_CALL;
+    msg->header.serial = serial;
+    msg->header.proc = proc;
+
+    if (virNetMessageEncodeHeader(msg) < 0)
+        goto error;
+
+    if (virNetMessageEncodePayload(msg, args_filter, args) < 0)
+        goto error;
+
+    if (virNetClientSend(client, msg, true) < 0)
+        goto error;
+
+    /* None of these 3 should ever happen here, because
+     * virNetClientSend should have validated the reply,
+     * but it doesn't hurt to check again.
+     */
+    if (msg->header.type != VIR_NET_REPLY) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    _("Unexpected message type %d"), msg->header.type);
+        goto error;
+    }
+    if (msg->header.proc != proc) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    _("Unexpected message proc %d != %d"),
+                    msg->header.proc, proc);
+        goto error;
+    }
+    if (msg->header.serial != serial) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    _("Unexpected message serial %d != %d"),
+                    msg->header.serial, serial);
+        goto error;
+    }
+
+    switch (msg->header.status) {
+    case VIR_NET_OK:
+        if (virNetMessageDecodePayload(msg, ret_filter, ret) < 0)
+            goto error;
+        break;
+
+    case VIR_NET_ERROR:
+        virNetClientProgramDispatchError(prog, msg);
+        goto error;
+
+    default:
+        virNetError(VIR_ERR_RPC,
+                    _("Unexpected message status %d"), msg->header.status);
+        goto error;
+    }
+
+    virNetMessageFree(msg);
+
+    return 0;
+
+error:
+    virNetMessageFree(msg);
+    return -1;
+}
diff --git a/src/rpc/virnetclientprogram.h b/src/rpc/virnetclientprogram.h
new file mode 100644
index 0000000..82ae2c6
--- /dev/null
+++ b/src/rpc/virnetclientprogram.h
@@ -0,0 +1,85 @@
+/*
+ * virnetclientprogram.h: generic network RPC client program
+ *
+ * Copyright (C) 2006-2011 Red Hat,
Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_CLIENT_PROGRAM_H__
+# define __VIR_NET_CLIENT_PROGRAM_H__
+
+# include <rpc/types.h>
+# include <rpc/xdr.h>
+
+# include "virnetmessage.h"
+
+typedef struct _virNetClient virNetClient;
+typedef virNetClient *virNetClientPtr;
+
+typedef struct _virNetClientProgram virNetClientProgram;
+typedef virNetClientProgram *virNetClientProgramPtr;
+
+typedef struct _virNetClientProgramEvent virNetClientProgramEvent;
+typedef virNetClientProgramEvent *virNetClientProgramEventPtr;
+
+typedef struct _virNetClientProgramErrorHandler virNetClientProgramErrorHander;
+typedef virNetClientProgramErrorHander *virNetClientProgramErrorHanderPtr;
+
+
+typedef void (*virNetClientProgramDispatchFunc)(virNetClientProgramPtr prog,
+                                                virNetClientPtr client,
+                                                void *msg,
+                                                void *opaque);
+
+struct _virNetClientProgramEvent {
+    int proc;
+    virNetClientProgramDispatchFunc func;
+    size_t msg_len;
+    xdrproc_t msg_filter;
+};
+
+virNetClientProgramPtr virNetClientProgramNew(unsigned program,
+                                              unsigned version,
+                                              virNetClientProgramEventPtr events,
+                                              size_t nevents,
+                                              void *eventOpaque);
+
+unsigned virNetClientProgramGetProgram(virNetClientProgramPtr prog);
+unsigned virNetClientProgramGetVersion(virNetClientProgramPtr prog);
+
+void virNetClientProgramRef(virNetClientProgramPtr prog);
+
+void virNetClientProgramFree(virNetClientProgramPtr prog);
+
+int virNetClientProgramMatches(virNetClientProgramPtr prog,
+                               virNetMessagePtr msg);
+
+int virNetClientProgramDispatch(virNetClientProgramPtr prog,
+                                virNetClientPtr client,
+                                virNetMessagePtr msg);
+
+int virNetClientProgramCall(virNetClientProgramPtr prog,
+                            virNetClientPtr client,
+                            unsigned serial,
+                            int proc,
+                            xdrproc_t args_filter, void *args,
+                            xdrproc_t ret_filter, void *ret);
+
+
+
+#endif /* __VIR_NET_CLIENT_PROGRAM_H__ */
diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c
new file mode 100644
index 0000000..42836e4
--- /dev/null
+++ b/src/rpc/virnetclientstream.c
@@ -0,0 +1,536 @@
+/*
+ * virnetclientstream.c: generic network RPC client stream
+ *
+ * Copyright (C) 2006-2011 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#include "virnetclientstream.h"
+#include "virnetclient.h"
+#include "memory.h"
+#include "virterror_internal.h"
+#include "logging.h"
+#include "event.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+#define virNetError(code, ...)                                    \
+    virReportErrorHelper(VIR_FROM_THIS, code, __FILE__,           \
+                         __FUNCTION__, __LINE__, __VA_ARGS__)
+
+struct _virNetClientStream {
+    virNetClientProgramPtr prog;
+    int proc;
+    unsigned serial;
+    int refs;
+
+    virError err;
+
+    /* XXX this buffer is unbounded if the client
+     * app has domain events registered, since packets
+     * may be read off wire, while app isn't ready to
+     * recv them. Figure out how to address this some
+     * time by stopping consuming any incoming data
+     * off the socket....
+     */
+    char *incoming;
+    size_t incomingOffset;
+    size_t incomingLength;
+
+
+    virNetClientStreamEventCallback cb;
+    void *cbOpaque;
+    virFreeCallback cbFree;
+    int cbEvents;
+    int cbTimer;
+    int cbDispatch;
+};
+
+
+/*
+ * Arm the stream's timer (timeout 0) when the registered callback has
+ * something to be told about: buffered incoming data while READABLE
+ * is requested, or WRITABLE being requested at all.  Otherwise park
+ * the timer (timeout -1).  Does nothing if no callback is registered.
+ */
+static void
+virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
+{
+    if (!st->cb)
+        return;
+
+    VIR_DEBUG("Check timer offset=%zu %d", st->incomingOffset, st->cbEvents);
+
+    if ((st->incomingOffset &&
+         (st->cbEvents & VIR_STREAM_EVENT_READABLE)) ||
+        (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) {
+        VIR_DEBUG("Enabling event timer");
+        virEventUpdateTimeout(st->cbTimer, 0);
+    } else {
+        VIR_DEBUG("Disabling event timer");
+        virEventUpdateTimeout(st->cbTimer, -1);
+    }
+}
+
+
+/*
+ * Timer callback: work out which of the requested events are pending
+ * and invoke the registered stream callback with them.
+ */
+static void
+virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
+{
+    virNetClientStreamPtr st = opaque;
+    int events = 0;
+
+    /* XXX we need a mutex on 'st' to protect this callback */
+
+    if (st->cb &&
+        (st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
+        st->incomingOffset)
+        events |= VIR_STREAM_EVENT_READABLE;
+    if (st->cb &&
+        (st->cbEvents & VIR_STREAM_EVENT_WRITABLE))
+        events |= VIR_STREAM_EVENT_WRITABLE;
+
+    VIR_DEBUG("Got Timer dispatch %d %d offset=%zu", events, st->cbEvents, st->incomingOffset);
+    if (events) {
+        virNetClientStreamEventCallback cb = st->cb;
+        void *cbOpaque = st->cbOpaque;
+        virFreeCallback cbFree = st->cbFree;
+
+        /* While cbDispatch is set, EventRemoveCallback() leaves the
+         * opaque data alone; if the callback got removed from inside
+         * 'cb', the free function captured above is run here
+         * instead. */
+        st->cbDispatch = 1;
+        (cb)(st, events, cbOpaque);
+        st->cbDispatch = 0;
+
+        if (!st->cb && cbFree)
+            (cbFree)(cbOpaque);
+    }
+}
+
+
+/* Free callback of the timer: drops the stream reference taken in
+ * virNetClientStreamEventAddCallback(). */
+static void
+virNetClientStreamEventTimerFree(void *opaque)
+{
+    virNetClientStreamPtr st = opaque;
+    virNetClientStreamFree(st);
+}
+
+
+/*
+ * Create a stream tied to procedure 'proc' and call 'serial' of
+ * 'prog'.  The stream takes its own reference on 'prog'.
+ *
+ * Returns a new stream holding one reference, or NULL on OOM.
+ */
+virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
+                                            int proc,
+                                            unsigned serial)
+{
+    virNetClientStreamPtr st;
+
+    if (VIR_ALLOC(st) < 0) {
+        virReportOOMError();
+        return NULL;
+    }
+
+    virNetClientProgramRef(prog);
+
+    st->refs = 1;
+    st->prog = prog;
+    st->proc = proc;
+    st->serial = serial;
+
+    return st;
+}
+
+
+/* Take an additional reference on 'st'. */
+void virNetClientStreamRef(virNetClientStreamPtr st)
+{
+    st->refs++;
+}
+
+/* Drop one reference on 'st' (NULL is a no-op).  The last reference
+ * frees the saved error, the buffered data and the program reference. */
+void virNetClientStreamFree(virNetClientStreamPtr st)
+{
+    if (!st)
+        return;
+
+    st->refs--;
+    if (st->refs > 0)
+        return;
+
+    virResetError(&st->err);
+    VIR_FREE(st->incoming);
+    virNetClientProgramFree(st->prog);
+    VIR_FREE(st);
+}
+
+/* Returns true if 'msg' is addressed to this stream: same program
+ * (per virNetClientProgramMatches()), procedure and serial. */
+bool virNetClientStreamMatches(virNetClientStreamPtr st,
+                               virNetMessagePtr msg)
+{
+    if (virNetClientProgramMatches(st->prog, msg) &&
+        st->proc == msg->header.proc &&
+        st->serial == msg->header.serial)
+        return true;
+    return false;
+}
+
+
+/* Re-raise the error saved by virNetClientStreamSetError().  Returns
+ * true if an error was raised, false if none is saved. */
+bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
+{
+    if (st->err.code == VIR_ERR_OK)
+        return false;
+
+    virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
+                      st->err.domain,
+                      st->err.code,
+                      st->err.level,
+                      st->err.str1,
+                      st->err.str2,
+                      st->err.str3,
+                      st->err.int1,
+                      st->err.int2,
+                      "%s", st->err.message ? st->err.message : _("Unknown error"));
+
+    return true;
+}
+
+
+/*
+ * Decode the error payload of 'msg' and save it on the stream,
+ * replacing any error saved earlier.
+ *
+ * Returns 0 on success, -1 if the payload could not be decoded.
+ */
+int virNetClientStreamSetError(virNetClientStreamPtr st,
+                               virNetMessagePtr msg)
+{
+    virNetMessageError err;
+    int ret = -1;
+
+    if (st->err.code != VIR_ERR_OK)
+        VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));
+
+    virResetError(&st->err);
+    memset(&err, 0, sizeof(err));
+
+    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
+        goto cleanup;
+
+    if (err.domain == VIR_FROM_REMOTE &&
+        err.code == VIR_ERR_RPC &&
+        err.level == VIR_ERR_ERROR &&
+        err.message &&
+        STRPREFIX(*err.message, "unknown procedure")) {
+        st->err.code = VIR_ERR_NO_SUPPORT;
+    } else {
+        st->err.code = err.code;
+    }
+    /* Each string sits behind a pointer that may be NULL (see the
+     * err.message check above).  Steal the ones that are present
+     * and clear them in 'err', otherwise xdr_free() below would
+     * release memory 'st->err' still points to. */
+    if (err.message) {
+        st->err.message = *err.message;
+        *err.message = NULL;
+    }
+    st->err.domain = err.domain;
+    st->err.level = err.level;
+    if (err.str1) {
+        st->err.str1 = *err.str1;
+        *err.str1 = NULL;
+    }
+    if (err.str2) {
+        st->err.str2 = *err.str2;
+        *err.str2 = NULL;
+    }
+    if (err.str3) {
+        st->err.str3 = *err.str3;
+        *err.str3 = NULL;
+    }
+    st->err.int1 = err.int1;
+    st->err.int2 = err.int2;
+
+    ret = 0;
+
+cleanup:
+    xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
+    return ret;
+}
+
+
+/*
+ * Append the bytes of 'msg' between bufferOffset and bufferLength to
+ * the stream's incoming buffer, growing the buffer when needed.
+ *
+ * Returns 0 on success, -1 if the buffer could not be grown.
+ */
+int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
+                                  virNetMessagePtr msg)
+{
+    size_t avail = st->incomingLength - st->incomingOffset;
+    size_t need = msg->bufferLength - msg->bufferOffset;
+
+    if (need > avail) {
+        size_t extra = need - avail;
+        if (VIR_REALLOC_N(st->incoming,
+                          st->incomingLength + extra) < 0) {
+            VIR_DEBUG("Out of memory handling stream data");
+            return -1;
+        }
+        st->incomingLength += extra;
+    }
+
+    memcpy(st->incoming + st->incomingOffset,
+           msg->buffer + msg->bufferOffset,
+           msg->bufferLength - msg->bufferOffset);
+    st->incomingOffset += (msg->bufferLength - msg->bufferOffset);
+
+    VIR_DEBUG("Stream incoming data offset %zu length %zu",
+              st->incomingOffset, st->incomingLength);
+    return 0;
+}
+
+
+/*
+ * Send one stream packet with the given 'status'.  VIR_NET_CONTINUE
+ * packets carry 'nbytes' of 'data' and are sent without expecting a
+ * reply; any other status sends an empty payload and expects one.
+ *
+ * Returns 'nbytes' on success, -1 on failure.
+ */
+int virNetClientStreamSendPacket(virNetClientStreamPtr st,
+                                 virNetClientPtr client,
+                                 int status,
+                                 const char *data,
+                                 size_t nbytes)
+{
+    virNetMessagePtr msg;
+    bool wantReply;
+    VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes);
+
+    if (!(msg = virNetMessageNew()))
+        return -1;
+
+    msg->header.prog = virNetClientProgramGetProgram(st->prog);
+    msg->header.vers = virNetClientProgramGetVersion(st->prog);
+    msg->header.status = status;
+    msg->header.type = VIR_NET_STREAM;
+    msg->header.serial = st->serial;
+    msg->header.proc = st->proc;
+
+    if (virNetMessageEncodeHeader(msg) < 0)
+        goto error;
+
+    /* Data packets are async fire&forget, but OK/ERROR packets
+     * need a synchronous confirmation
+     */
+    if (status == VIR_NET_CONTINUE) {
+        if (virNetMessageEncodePayloadRaw(msg, data, nbytes) < 0)
+            goto error;
+        wantReply = false;
+    } else {
+        if (virNetMessageEncodePayloadRaw(msg, NULL, 0) < 0)
+            goto error;
+        wantReply = true;
+    }
+
+    if (virNetClientSend(client, msg, wantReply) < 0)
+        goto error;
+
+    /* virNetClientSend() frees only its call wrapper, not 'msg' */
+    virNetMessageFree(msg);
+
+    return nbytes;
+
+error:
+    virNetMessageFree(msg);
+    return -1;
+}
+
+/*
+ * Copy up to 'nbytes' of buffered stream data into 'data'.  If
+ * nothing is buffered and 'nonblock' is false, a dummy message is
+ * pushed through virNetClientSend() first to wait for data.
+ *
+ * Returns the number of bytes copied (0 if there is still nothing
+ * buffered after waiting), -2 if 'nonblock' is set and nothing is
+ * buffered, or -1 on error.
+ */
+int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
+                                 virNetClientPtr client,
+                                 char *data,
+                                 size_t nbytes,
+                                 bool nonblock)
+{
+    int rv = -1;
+    VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d",
+              st, client, data, nbytes, nonblock);
+    if (!st->incomingOffset) {
+        virNetMessagePtr msg;
+        int ret;
+
+        if (nonblock) {
+            VIR_DEBUG("Non-blocking mode and no data available");
+            rv = -2;
+            goto cleanup;
+        }
+
+        if (!(msg = virNetMessageNew())) {
+            virReportOOMError();
+            goto cleanup;
+        }
+
+        msg->header.prog = virNetClientProgramGetProgram(st->prog);
+        msg->header.vers = virNetClientProgramGetVersion(st->prog);
+        msg->header.type = VIR_NET_STREAM;
+        msg->header.serial = st->serial;
+        msg->header.proc = st->proc;
+
+        VIR_DEBUG("Dummy packet to wait for stream data");
+        ret = virNetClientSend(client, msg, true);
+
+        virNetMessageFree(msg);
+
+        if (ret < 0)
+            goto cleanup;
+    }
+
+    VIR_DEBUG("After IO %zu", st->incomingOffset);
+    if (st->incomingOffset) {
+        size_t want = st->incomingOffset;
+        if (want > nbytes)
+            want = nbytes;
+        memcpy(data, st->incoming, want);
+        if (want < st->incomingOffset) {
+            memmove(st->incoming, st->incoming + want, st->incomingOffset - want);
+            st->incomingOffset -= want;
+        } else {
+            VIR_FREE(st->incoming);
+            st->incomingOffset = st->incomingLength = 0;
+        }
+        rv = want;
+    } else {
+        rv = 0;
+    }
+
+    virNetClientStreamEventTimerUpdate(st);
+
+cleanup:
+    return rv;
+}
+
+
+/*
+ * Register the stream's (single) event callback.  Events are
+ * delivered from a timer, which holds a reference on 'st' until
+ * virNetClientStreamEventTimerFree() runs.
+ *
+ * Returns 0 on success, -1 on failure, which includes a callback
+ * already being registered.
+ */
+int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
+                                       int events,
+                                       virNetClientStreamEventCallback cb,
+                                       void *opaque,
+                                       virFreeCallback ff)
+{
+    if (st->cb) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    "%s", _("multiple stream callbacks not supported"));
+        return -1;
+    }
+
+    virNetClientStreamRef(st);
+    if ((st->cbTimer =
+         virEventAddTimeout(-1,
+                            virNetClientStreamEventTimer,
+                            st,
+                            virNetClientStreamEventTimerFree)) < 0) {
+        virNetClientStreamFree(st);
+        return -1;
+    }
+
+    st->cb = cb;
+    st->cbOpaque = opaque;
+    st->cbFree = ff;
+    st->cbEvents = events;
+
+    virNetClientStreamEventTimerUpdate(st);
+
+    return 0;
+}
+
+/* Change the set of events the registered callback wants.  Returns
+ * 0 on success, -1 if no callback is registered. */
+int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
+                                          int events)
+{
+    if (!st->cb) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    "%s", _("no stream callback registered"));
+        return -1;
+    }
+
+    st->cbEvents = events;
+
+    virNetClientStreamEventTimerUpdate(st);
+
+    return 0;
+}
+
+/*
+ * Unregister the event callback and remove its timer.  The opaque
+ * data is released here, unless the callback is being dispatched
+ * right now, in which case virNetClientStreamEventTimer() does it.
+ *
+ * Returns 0 on success, -1 if no callback is registered.
+ */
+int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
+{
+    if (!st->cb) {
+        virNetError(VIR_ERR_INTERNAL_ERROR,
+                    "%s", _("no stream callback registered"));
+        return -1;
+    }
+
+    if (!st->cbDispatch &&
+        st->cbFree)
+        (st->cbFree)(st->cbOpaque);
+    st->cb = NULL;
+    st->cbOpaque = NULL;
+    st->cbFree = NULL;
+    st->cbEvents = 0;
+    virEventRemoveTimeout(st->cbTimer);
+
+    return 0;
+}
diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h
new file mode 100644
index 0000000..6c8d538
--- /dev/null
+++ b/src/rpc/virnetclientstream.h
@@ -0,0 +1,76 @@
+/*
+ * virnetclientstream.h: generic network RPC client stream
+ *
+ * Copyright (C) 2006-2011 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#ifndef __VIR_NET_CLIENT_STREAM_H__
+# define __VIR_NET_CLIENT_STREAM_H__
+
+# include "virnetclientprogram.h"
+
+typedef struct _virNetClientStream virNetClientStream;
+typedef virNetClientStream *virNetClientStreamPtr;
+
+typedef void (*virNetClientStreamEventCallback)(virNetClientStreamPtr stream,
+                                                int events, void *opaque);
+
+virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
+                                            int proc,
+                                            unsigned serial);
+
+void virNetClientStreamRef(virNetClientStreamPtr st);
+
+void virNetClientStreamFree(virNetClientStreamPtr st);
+
+bool virNetClientStreamRaiseError(virNetClientStreamPtr st);
+
+int virNetClientStreamSetError(virNetClientStreamPtr st,
+                               virNetMessagePtr msg);
+
+bool virNetClientStreamMatches(virNetClientStreamPtr st,
+                               virNetMessagePtr msg);
+
+int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
+                                  virNetMessagePtr msg);
+
+int virNetClientStreamSendPacket(virNetClientStreamPtr st,
+                                 virNetClientPtr client,
+                                 int status,
+                                 const char *data,
+                                 size_t nbytes);
+
+int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
+                                 virNetClientPtr client,
+                                 char *data,
+                                 size_t nbytes,
+                                 bool nonblock);
+
+int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
+                                       int events,
+                                       virNetClientStreamEventCallback cb,
+                                       void *opaque,
+                                       virFreeCallback ff);
+
+int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
+                                          int events);
+int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st);
+
+
+#endif /* __VIR_NET_CLIENT_STREAM_H__ */
-- 
1.7.4.4

Move the daemon/remote_generator.pl to src/rpc/gendispatch.pl and move the src/remote/rpcgen_fix.pl to src/rpc/genprotocol.pl * daemon/Makefile.am: Update for new name/location of generator * src/Makefile.am: Update for new name/location of generator --- daemon/Makefile.am | 41 ++++++++++---------- src/Makefile.am | 32 +++++++++------- .../remote_generator.pl => src/rpc/gendispatch.pl | 0 src/{remote/rpcgen_fix.pl => rpc/genprotocol.pl} | 0 4 files changed, 38 insertions(+), 35 deletions(-) rename daemon/remote_generator.pl => src/rpc/gendispatch.pl (100%) rename src/{remote/rpcgen_fix.pl => rpc/genprotocol.pl} (100%) mode change 100755 => 100644 diff --git a/daemon/Makefile.am b/daemon/Makefile.am index 92d154f..efc2e8e 100644 --- a/daemon/Makefile.am +++ b/daemon/Makefile.am @@ -28,7 +28,6 @@ AVAHI_SOURCES = \ DISTCLEANFILES = EXTRA_DIST = \ - remote_generator.pl \ remote_dispatch_bodies.h \ qemu_dispatch_bodies.h \ libvirtd.conf \ @@ -56,54 +55,54 @@ BUILT_SOURCES = REMOTE_PROTOCOL = $(top_srcdir)/src/remote/remote_protocol.x QEMU_PROTOCOL = $(top_srcdir)/src/remote/qemu_protocol.x -$(srcdir)/remote_dispatch_prototypes.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/remote_dispatch_prototypes.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(REMOTE_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -c -p remote \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -c -p remote \ $(REMOTE_PROTOCOL) > $@ -$(srcdir)/remote_dispatch_table.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/remote_dispatch_table.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(REMOTE_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -c -t remote \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -c -t remote \ $(REMOTE_PROTOCOL) > $@ -$(srcdir)/remote_dispatch_args.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/remote_dispatch_args.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(REMOTE_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -c -a remote \ + 
$(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -c -a remote \ $(REMOTE_PROTOCOL) > $@ -$(srcdir)/remote_dispatch_ret.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/remote_dispatch_ret.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(REMOTE_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -c -r remote \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -c -r remote \ $(REMOTE_PROTOCOL) > $@ -$(srcdir)/remote_dispatch_bodies.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/remote_dispatch_bodies.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(REMOTE_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -c -b remote \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -c -b remote \ $(REMOTE_PROTOCOL) > $@ -$(srcdir)/qemu_dispatch_prototypes.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/qemu_dispatch_prototypes.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(QEMU_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -p qemu \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -p qemu \ $(QEMU_PROTOCOL) > $@ -$(srcdir)/qemu_dispatch_table.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/qemu_dispatch_table.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(QEMU_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -t qemu \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -t qemu \ $(QEMU_PROTOCOL) > $@ -$(srcdir)/qemu_dispatch_args.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/qemu_dispatch_args.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(QEMU_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -a qemu \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -a qemu \ $(QEMU_PROTOCOL) > $@ -$(srcdir)/qemu_dispatch_ret.h: $(srcdir)/remote_generator.pl \ +$(srcdir)/qemu_dispatch_ret.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(QEMU_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -r qemu \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -r qemu \ $(QEMU_PROTOCOL) > $@ -$(srcdir)/qemu_dispatch_bodies.h: 
$(srcdir)/remote_generator.pl \ +$(srcdir)/qemu_dispatch_bodies.h: $(srcdir)/../src/rpc/gendispatch.pl \ $(QEMU_PROTOCOL) - $(AM_V_GEN)perl -w $(srcdir)/remote_generator.pl -b qemu \ + $(AM_V_GEN)perl -w $(srcdir)/../src/rpc/gendispatch.pl -b qemu \ $(QEMU_PROTOCOL) > $@ if WITH_LIBVIRTD diff --git a/src/Makefile.am b/src/Makefile.am index e6f2a02..20f9ca6 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -166,14 +166,14 @@ REMOTE_PROTOCOL = $(srcdir)/remote/remote_protocol.x QEMU_PROTOCOL = $(srcdir)/remote/qemu_protocol.x REMOTE_DRIVER_PROTOCOL = $(REMOTE_PROTOCOL) $(QEMU_PROTOCOL) -$(srcdir)/remote/remote_client_bodies.h: $(REMOTE_PROTOCOL) \ - $(top_srcdir)/daemon/remote_generator.pl - $(AM_V_GEN)perl -w $(top_srcdir)/daemon/remote_generator.pl \ +remote/remote_client_bodies.h: $(srcdir)/rpc/gendispatch.pl \ + $(REMOTE_PROTOCOL) + $(AM_V_GEN)perl -w $(srcdir)/rpc/gendispatch.pl \ -c -k remote $(REMOTE_PROTOCOL) > $@ -$(srcdir)/remote/qemu_client_bodies.h: $(QEMU_PROTOCOL) \ - $(top_srcdir)/daemon/remote_generator.pl - $(AM_V_GEN)perl -w $(top_srcdir)/daemon/remote_generator.pl \ +remote/qemu_client_bodies.h: $(srcdir)/rpc/gendispatch.pl \ + $(QEMU_PROTOCOL) + $(AM_V_GEN)perl -w $(srcdir)/rpc/gendispatch.pl \ -k qemu $(QEMU_PROTOCOL) > $@ REMOTE_DRIVER_SOURCES = \ @@ -182,8 +182,7 @@ REMOTE_DRIVER_SOURCES = \ $(REMOTE_DRIVER_GENERATED) EXTRA_DIST += $(REMOTE_DRIVER_PROTOCOL) \ - $(REMOTE_DRIVER_GENERATED) \ - remote/rpcgen_fix.pl + $(REMOTE_DRIVER_GENERATED) # Ensure that we don't change the struct or member names or member ordering # in remote_protocol.x The embedded perl below needs a few comments, and @@ -552,13 +551,13 @@ libvirt_driver_remote_la_SOURCES = $(REMOTE_DRIVER_SOURCES) $(srcdir)/remote/remote_driver.c: $(REMOTE_DRIVER_GENERATED) -%protocol.c: %protocol.x %protocol.h $(srcdir)/remote/rpcgen_fix.pl - $(AM_V_GEN)perl -w $(srcdir)/remote/rpcgen_fix.pl $(RPCGEN) -c \ - $< $@ +%protocol.c: %protocol.x %protocol.h $(srcdir)/rpc/genprotocol.pl + 
$(AM_V_GEN)perl -w $(srcdir)/rpc/genprotocol.pl $(RPCGEN) -c \ + $< $@ -%protocol.h: %protocol.x $(srcdir)/remote/rpcgen_fix.pl - $(AM_V_GEN)perl -w $(srcdir)/remote/rpcgen_fix.pl $(RPCGEN) -h \ - $< $@ +%protocol.h: %protocol.x $(srcdir)/rpc/genprotocol.pl + $(AM_V_GEN)perl -w $(srcdir)/rpc/genprotocol.pl $(RPCGEN) -h \ + $< $@ endif @@ -1156,6 +1155,11 @@ EXTRA_DIST += $(LIBVIRT_QEMU_SYMBOL_FILE) noinst_LTLIBRARIES += libvirt-net-rpc.la libvirt-net-rpc-server.la libvirt-net-rpc-client.la +EXTRA_DIST += \ + rpc/virnetprotocol.x \ + rpc/gendispatch.pl \ + rpc/genprotocol.pl + libvirt_net_rpc_la_SOURCES = \ rpc/virnetmessage.h rpc/virnetmessage.c \ rpc/virnetprotocol.h rpc/virnetprotocol.c \ diff --git a/daemon/remote_generator.pl b/src/rpc/gendispatch.pl similarity index 100% rename from daemon/remote_generator.pl rename to src/rpc/gendispatch.pl diff --git a/src/remote/rpcgen_fix.pl b/src/rpc/genprotocol.pl old mode 100755 new mode 100644 similarity index 100% rename from src/remote/rpcgen_fix.pl rename to src/rpc/genprotocol.pl -- 1.7.4.4
participants (1)
- Daniel P. Berrange