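Stream validation was inconsistent about whether to report errors
with VIR_FROM_NONE or VIR_FROM_STREAMS, and several stream APIs
reported an invalid stream as VIR_ERR_INVALID_CONN rather than
VIR_ERR_INVALID_STREAM. Furthermore, many APIs should ensure that
a stream is tied to the same connection as the other object being
operated on, while some APIs failed to validate the stream at all.
As in previous patches, use common macros to make this nicer.
Because the stream must now share the other object's connection,
the read-only check only needs to consult that connection's flags.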
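* src/datatypes.h (virCheckStreamReturn, virCheckStreamGoto):
New macros.
(VIR_IS_STREAM, VIR_IS_CONNECTED_STREAM): Drop now-unused macros.
* src/libvirt.c: Use new macros throughout.
(virDomainScreenshot, virStorageVolDownload, virStorageVolUpload)
(virDomainOpenConsole, virDomainOpenChannel): Also require the
stream to share the connection of the domain or volume.
(virLibStreamError): Drop unused macro.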
Signed-off-by: Eric Blake <eblake(a)redhat.com>
---
src/datatypes.h | 29 +++++++++++--
src/libvirt.c | 128 ++++++++++++++++++++++----------------------------------
2 files changed, 76 insertions(+), 81 deletions(-)
diff --git a/src/datatypes.h b/src/datatypes.h
index 024a2e3..6f092a7 100644
--- a/src/datatypes.h
+++ b/src/datatypes.h
@@ -192,10 +192,31 @@ extern virClassPtr virStoragePoolClass;
} \
} while (0)
-# define VIR_IS_STREAM(obj) \
- (virObjectIsClass((obj), virStreamClass))
-# define VIR_IS_CONNECTED_STREAM(obj) \
- (VIR_IS_STREAM(obj) && virObjectIsClass((obj)->conn, virConnectClass))
+/* Check that obj is a stream tied to a connection, else return retval. */
+# define virCheckStreamReturn(obj, retval)                              \
+    do {                                                                \
+        virStreamPtr _stream = (obj);                                   \
+        if (!(virObjectIsClass(_stream, virStreamClass) &&              \
+              virObjectIsClass(_stream->conn, virConnectClass))) {      \
+            virReportErrorHelper(VIR_FROM_STREAMS, VIR_ERR_INVALID_STREAM, \
+                                 __FILE__, __FUNCTION__, __LINE__,      \
+                                 __FUNCTION__);                         \
+            virDispatchError(NULL);                                     \
+            return retval;                                              \
+        }                                                               \
+    } while (0)
+/* As virCheckStreamReturn, but goto label and skip virDispatchError. */
+# define virCheckStreamGoto(obj, label)                                 \
+    do {                                                                \
+        virStreamPtr _stream = (obj);                                   \
+        if (!(virObjectIsClass(_stream, virStreamClass) &&              \
+              virObjectIsClass(_stream->conn, virConnectClass))) {      \
+            virReportErrorHelper(VIR_FROM_STREAMS, VIR_ERR_INVALID_STREAM, \
+                                 __FILE__, __FUNCTION__, __LINE__,      \
+                                 __FUNCTION__);                         \
+            goto label;                                                 \
+        }                                                               \
+    } while (0)
# define VIR_IS_NWFILTER(obj) \
(virObjectIsClass((obj), virNWFilterClass))
diff --git a/src/libvirt.c b/src/libvirt.c
index eaa4c89..b460c00 100644
--- a/src/libvirt.c
+++ b/src/libvirt.c
@@ -518,9 +518,6 @@ DllMain(HINSTANCE instance ATTRIBUTE_UNUSED,
#define virLibDomainError(code, ...) \
virReportErrorHelper(VIR_FROM_DOM, code, __FILE__, \
__FUNCTION__, __LINE__, __VA_ARGS__)
-#define virLibStreamError(code, ...) \
- virReportErrorHelper(VIR_FROM_STREAMS, code, __FILE__, \
- __FUNCTION__, __LINE__, __VA_ARGS__)
#define virLibNWFilterError(code, ...) \
virReportErrorHelper(VIR_FROM_NWFILTER, code, __FILE__, \
__FUNCTION__, __LINE__, __VA_ARGS__)
@@ -3076,14 +3073,18 @@ virDomainScreenshot(virDomainPtr domain,
virResetLastError();
virCheckDomainReturn(domain, NULL);
- if (!VIR_IS_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__);
- return NULL;
+ virCheckStreamGoto(stream, error);
+ virCheckReadOnlyGoto(domain->conn->flags, error);
+
+ if (domain->conn != stream->conn) {
+ virReportInvalidArg(stream,
+ _("stream in %s must match connection of domain
'%s'"),
+ __FUNCTION__, domain->name);
+ goto error;
}
- virCheckReadOnlyGoto(domain->conn->flags | stream->conn->flags, error);
if (domain->conn->driver->domainScreenshot) {
- char * ret;
+ char *ret;
ret = domain->conn->driver->domainScreenshot(domain, stream,
screen, flags);
@@ -13645,14 +13646,16 @@ virStorageVolDownload(virStorageVolPtr vol,
virResetLastError();
virCheckStorageVolReturn(vol, -1);
+ virCheckStreamGoto(stream, error);
+ virCheckReadOnlyGoto(vol->conn->flags, error);
- if (!VIR_IS_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__);
- return -1;
+ if (vol->conn != stream->conn) {
+ virReportInvalidArg(stream,
+ _("stream in %s must match connection of volume
'%s'"),
+ __FUNCTION__, vol->name);
+ goto error;
}
- virCheckReadOnlyGoto(vol->conn->flags | stream->conn->flags, error);
-
if (vol->conn->storageDriver &&
vol->conn->storageDriver->storageVolDownload) {
int ret;
@@ -13709,14 +13712,16 @@ virStorageVolUpload(virStorageVolPtr vol,
virResetLastError();
virCheckStorageVolReturn(vol, -1);
+ virCheckStreamGoto(stream, error);
+ virCheckReadOnlyGoto(vol->conn->flags, error);
- if (!VIR_IS_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__);
- return -1;
+ if (vol->conn != stream->conn) {
+ virReportInvalidArg(stream,
+ _("stream in %s must match connection of volume
'%s'"),
+ __FUNCTION__, vol->name);
+ goto error;
}
- virCheckReadOnlyGoto(vol->conn->flags | stream->conn->flags, error);
-
if (vol->conn->storageDriver &&
vol->conn->storageDriver->storageVolUpload) {
int ret;
@@ -15632,11 +15637,8 @@ virStreamRef(virStreamPtr stream)
virResetLastError();
- if ((!VIR_IS_CONNECTED_STREAM(stream))) {
- virLibConnError(VIR_ERR_INVALID_STREAM, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
+ virCheckStreamReturn(stream, -1);
+
virObjectRef(stream);
return 0;
}
@@ -15715,12 +15717,7 @@ virStreamSend(virStreamPtr stream,
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
-
+ virCheckStreamReturn(stream, -1);
virCheckNonNullArgGoto(data, error);
if (stream->driver &&
@@ -15813,12 +15810,7 @@ virStreamRecv(virStreamPtr stream,
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
-
+ virCheckStreamReturn(stream, -1);
virCheckNonNullArgGoto(data, error);
if (stream->driver &&
@@ -15892,12 +15884,7 @@ virStreamSendAll(virStreamPtr stream,
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
-
+ virCheckStreamReturn(stream, -1);
virCheckNonNullArgGoto(handler, cleanup);
if (stream->flags & VIR_STREAM_NONBLOCK) {
@@ -15990,12 +15977,7 @@ virStreamRecvAll(virStreamPtr stream,
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
-
+ virCheckStreamReturn(stream, -1);
virCheckNonNullArgGoto(handler, cleanup);
if (stream->flags & VIR_STREAM_NONBLOCK) {
@@ -16064,11 +16046,7 @@ virStreamEventAddCallback(virStreamPtr stream,
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
+ virCheckStreamReturn(stream, -1);
if (stream->driver &&
stream->driver->streamEventAddCallback) {
@@ -16107,11 +16085,7 @@ virStreamEventUpdateCallback(virStreamPtr stream,
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
+ virCheckStreamReturn(stream, -1);
if (stream->driver &&
stream->driver->streamEventUpdateCallback) {
@@ -16145,11 +16119,7 @@ virStreamEventRemoveCallback(virStreamPtr stream)
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
+ virCheckStreamReturn(stream, -1);
if (stream->driver &&
stream->driver->streamEventRemoveCallback) {
@@ -16190,11 +16160,7 @@ virStreamFinish(virStreamPtr stream)
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
+ virCheckStreamReturn(stream, -1);
if (stream->driver &&
stream->driver->streamFinish) {
@@ -16233,11 +16199,7 @@ virStreamAbort(virStreamPtr stream)
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
+ virCheckStreamReturn(stream, -1);
if (!stream->driver) {
VIR_DEBUG("aborting unused stream");
@@ -16281,11 +16243,7 @@ virStreamFree(virStreamPtr stream)
virResetLastError();
- if (!VIR_IS_CONNECTED_STREAM(stream)) {
- virLibConnError(VIR_ERR_INVALID_CONN, __FUNCTION__);
- virDispatchError(NULL);
- return -1;
- }
+ virCheckStreamReturn(stream, -1);
/* XXX Enforce shutdown before free'ing resources ? */
@@ -19293,8 +19251,16 @@ virDomainOpenConsole(virDomainPtr dom,
virCheckDomainReturn(dom, -1);
conn = dom->conn;
+ virCheckStreamGoto(st, error);
virCheckReadOnlyGoto(conn->flags, error);
+ if (conn != st->conn) {
+ virReportInvalidArg(st,
+ _("stream in %s must match connection of domain
'%s'"),
+ __FUNCTION__, dom->name);
+ goto error;
+ }
+
if (conn->driver->domainOpenConsole) {
int ret;
ret = conn->driver->domainOpenConsole(dom, dev_name, st, flags);
@@ -19349,8 +19315,16 @@ virDomainOpenChannel(virDomainPtr dom,
virCheckDomainReturn(dom, -1);
conn = dom->conn;
+ virCheckStreamGoto(st, error);
virCheckReadOnlyGoto(conn->flags, error);
+ if (conn != st->conn) {
+ virReportInvalidArg(st,
+ _("stream in %s must match connection of domain
'%s'"),
+ __FUNCTION__, dom->name);
+ goto error;
+ }
+
if (conn->driver->domainOpenChannel) {
int ret;
ret = conn->driver->domainOpenChannel(dom, name, st, flags);
--
1.8.4.2