/* * virnetclientstream.c: generic network RPC client stream * * Copyright (C) 2006-2011 Red Hat, Inc. * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library. If not, see * . */ #include #include "virnetclientstream.h" #include "virnetclient.h" #include "virerror.h" #include "virlog.h" #include "virthread.h" #define VIR_FROM_THIS VIR_FROM_RPC VIR_LOG_INIT("rpc.netclientstream"); struct _virNetClientStream { virObjectLockable parent; virNetClientProgram *prog; int proc; unsigned serial; virError err; /* XXX this buffer is unbounded if the client * app has domain events registered, since packets * may be read off wire, while app isn't ready to * recv them. Figure out how to address this some * time by stopping consuming any incoming data * off the socket.... */ virNetMessage *rx; bool incomingEOF; virNetClientStreamClosed closed; bool allowSkip; long long holeLength; /* Size of incoming hole in stream. */ virNetClientStreamEventCallback cb; void *cbOpaque; virFreeCallback cbFree; int cbEvents; int cbTimer; int cbDispatch; }; static virClass *virNetClientStreamClass; static void virNetClientStreamDispose(void *obj); static int virNetClientStreamOnceInit(void) { if (!VIR_CLASS_NEW(virNetClientStream, virClassForObjectLockable())) return -1; return 0; } VIR_ONCE_GLOBAL_INIT(virNetClientStream); static void virNetClientStreamEventTimerUpdate(virNetClientStream *st) { if (!st->cb) return; VIR_DEBUG("Check timer rx=%p cbEvents=%d", st->rx, st->cbEvents); if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed) && (st->cbEvents & VIR_STREAM_EVENT_READABLE)) || (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) { VIR_DEBUG("Enabling event timer"); virEventUpdateTimeout(st->cbTimer, 0); } else { VIR_DEBUG("Disabling event timer"); virEventUpdateTimeout(st->cbTimer, -1); } } static void virNetClientStreamEventTimer(int timer G_GNUC_UNUSED, void *opaque) { virNetClientStream *st = opaque; int events = 0; virObjectLock(st); if (st->cb && (st->cbEvents & VIR_STREAM_EVENT_READABLE) && (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed)) events |= VIR_STREAM_EVENT_READABLE; if (st->cb && (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) events |= VIR_STREAM_EVENT_WRITABLE; VIR_DEBUG("Got Timer dispatch events=%d cbEvents=%d rx=%p", events, st->cbEvents, st->rx); if (events) { virNetClientStreamEventCallback cb = st->cb; void *cbOpaque = st->cbOpaque; virFreeCallback cbFree = st->cbFree; st->cbDispatch = 1; virObjectUnlock(st); (cb)(st, events, cbOpaque); virObjectLock(st); st->cbDispatch = 0; if (!st->cb && cbFree) (cbFree)(cbOpaque); } virObjectUnlock(st); } virNetClientStream *virNetClientStreamNew(virNetClientProgram *prog, int proc, unsigned serial, bool allowSkip) { virNetClientStream *st; if (virNetClientStreamInitialize() < 0) return NULL; if (!(st = virObjectLockableNew(virNetClientStreamClass))) return NULL; st->prog = virObjectRef(prog); st->proc = proc; st->serial = serial; st->allowSkip = allowSkip; return st; } void virNetClientStreamDispose(void *obj) { virNetClientStream *st = obj; virResetError(&st->err); while (st->rx) { virNetMessage *msg = st->rx; virNetMessageQueueServe(&st->rx); virNetMessageFree(msg); } virObjectUnref(st->prog); } bool virNetClientStreamMatches(virNetClientStream *st, virNetMessage *msg) { bool match = false; virObjectLock(st); if (virNetClientProgramMatches(st->prog, msg) && st->proc == msg->header.proc && st->serial == msg->header.serial) match = true; virObjectUnlock(st); return match; } static void virNetClientStreamRaiseError(virNetClientStream *st) { virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__, st->err.domain, st->err.code, st->err.level, st->err.str1, st->err.str2, st->err.str3, st->err.int1, st->err.int2, "%s", st->err.message ? st->err.message : _("Unknown error")); } /* MUST be called under stream or client lock */ int virNetClientStreamCheckState(virNetClientStream *st) { if (st->err.code != VIR_ERR_OK) { virNetClientStreamRaiseError(st); return -1; } if (st->closed) { virReportError(VIR_ERR_OPERATION_FAILED, "%s", _("stream is closed")); return -1; } return 0; } /* MUST be called under stream or client lock. This should * be called only for message that expect reply. */ int virNetClientStreamCheckSendStatus(virNetClientStream *st, virNetMessage *msg) { if (st->err.code != VIR_ERR_OK) { virNetClientStreamRaiseError(st); return -1; } /* We can not check if the message is dummy in a usual way * by checking msg->bufferLength because at this point message payload * is cleared. As caller must not call this function for messages * not expecting reply we can check for dummy messages just by status. */ if (msg->header.status == VIR_NET_CONTINUE) { if (st->closed) { virReportError(VIR_ERR_OPERATION_FAILED, "%s", _("stream is closed")); return -1; } return 0; } else if (msg->header.status == VIR_NET_OK && st->closed != VIR_NET_CLIENT_STREAM_CLOSED_FINISHED) { virReportError(VIR_ERR_OPERATION_FAILED, "%s", _("stream aborted by another thread")); return -1; } return 0; } void virNetClientStreamSetClosed(virNetClientStream *st, virNetClientStreamClosed closed) { virObjectLock(st); st->closed = closed; virNetClientStreamEventTimerUpdate(st); virObjectUnlock(st); } int virNetClientStreamSetError(virNetClientStream *st, virNetMessage *msg) { virNetMessageError err; int ret = -1; virObjectLock(st); if (st->err.code != VIR_ERR_OK) VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message)); virResetError(&st->err); memset(&err, 0, sizeof(err)); if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) goto cleanup; if (err.domain == VIR_FROM_REMOTE && err.code == VIR_ERR_RPC && err.level == VIR_ERR_ERROR && err.message && STRPREFIX(*err.message, "unknown procedure")) { st->err.code = VIR_ERR_NO_SUPPORT; } else { st->err.code = err.code; } if (err.message) { st->err.message = g_steal_pointer(err.message); } st->err.domain = err.domain; st->err.level = err.level; if (err.str1) { st->err.str1 = g_steal_pointer(err.str1); } if (err.str2) { st->err.str2 = g_steal_pointer(err.str2); } if (err.str3) { st->err.str3 = g_steal_pointer(err.str3); } st->err.int1 = err.int1; st->err.int2 = err.int2; virNetClientStreamEventTimerUpdate(st); ret = 0; cleanup: xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err); virObjectUnlock(st); return ret; } int virNetClientStreamQueuePacket(virNetClientStream *st, virNetMessage *msg) { virNetMessage *tmp_msg; VIR_DEBUG("Incoming stream message: stream=%p message=%p", st, msg); if (msg->bufferLength == msg->bufferOffset) { /* No payload means end of the stream. */ virObjectLock(st); st->incomingEOF = true; virNetClientStreamEventTimerUpdate(st); virObjectUnlock(st); return 0; } /* Unfortunately, we must allocate new message as the one we * get in @msg is going to be cleared later in the process. */ if (!(tmp_msg = virNetMessageNew(false))) return -1; /* Copy header */ memcpy(&tmp_msg->header, &msg->header, sizeof(msg->header)); /* Steal message buffer */ tmp_msg->buffer = g_steal_pointer(&msg->buffer); tmp_msg->bufferLength = msg->bufferLength; tmp_msg->bufferOffset = msg->bufferOffset; msg->bufferLength = msg->bufferOffset = 0; virObjectLock(st); /* Don't distinguish VIR_NET_STREAM and VIR_NET_STREAM_SKIP * here just yet. We want in order processing! */ virNetMessageQueuePush(&st->rx, tmp_msg); virNetClientStreamEventTimerUpdate(st); virObjectUnlock(st); return 0; } int virNetClientStreamSendPacket(virNetClientStream *st, virNetClient *client, int status, const char *data, size_t nbytes) { virNetMessage *msg; VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes); if (!(msg = virNetMessageNew(false))) return -1; virObjectLock(st); msg->header.prog = virNetClientProgramGetProgram(st->prog); msg->header.vers = virNetClientProgramGetVersion(st->prog); msg->header.status = status; msg->header.type = VIR_NET_STREAM; msg->header.serial = st->serial; msg->header.proc = st->proc; virObjectUnlock(st); if (virNetMessageEncodeHeader(msg) < 0) goto error; /* Data packets are async fire&forget, but OK/ERROR packets * need a synchronous confirmation */ if (status == VIR_NET_CONTINUE) { if (virNetMessageEncodePayloadRaw(msg, data, nbytes) < 0) goto error; } else { if (virNetMessageEncodePayloadRaw(msg, NULL, 0) < 0) goto error; } if (virNetClientSendStream(client, msg, st) < 0) goto error; virNetMessageFree(msg); return nbytes; error: virNetMessageFree(msg); return -1; } static int virNetClientStreamSetHole(virNetClientStream *st, long long length, unsigned int flags) { virCheckFlags(0, -1); virCheckPositiveArgReturn(length, -1); /* Shouldn't happen, But it's better to safe than sorry. */ if (st->holeLength) { virReportError(VIR_ERR_INTERNAL_ERROR, _("unprocessed hole of size %1$lld already in the queue"), st->holeLength); return -1; } st->holeLength += length; return 0; } /** * virNetClientStreamHandleHole: * @client: client * @st: stream * * Called whenever current message processed in the stream is * VIR_NET_STREAM_HOLE. The stream @st is expected to be locked * already. * * Returns: 0 on success, * -1 otherwise. */ static int virNetClientStreamHandleHole(virNetClient *client, virNetClientStream *st) { virNetMessage *msg; virNetStreamHole data; int ret = -1; VIR_DEBUG("client=%p st=%p", client, st); msg = st->rx; memset(&data, 0, sizeof(data)); /* We should not be called unless there's VIR_NET_STREAM_HOLE * message at the head of the list. But doesn't hurt to check */ if (!msg) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("No message in the queue")); goto cleanup; } if (msg->header.type != VIR_NET_STREAM_HOLE) { virReportError(VIR_ERR_INTERNAL_ERROR, _("Invalid message prog=%1$d type=%2$d serial=%3$u proc=%4$d"), msg->header.prog, msg->header.type, msg->header.serial, msg->header.proc); goto cleanup; } /* Server should not send us VIR_NET_STREAM_HOLE unless we * have requested so. But does not hurt to check ... */ if (!st->allowSkip) { virReportError(VIR_ERR_RPC, "%s", _("Unexpected stream hole")); goto cleanup; } if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetStreamHole, &data) < 0) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("Malformed stream hole packet")); goto cleanup; } virNetMessageQueueServe(&st->rx); virNetMessageFree(msg); if (virNetClientStreamSetHole(st, data.length, data.flags) < 0) goto cleanup; ret = 0; cleanup: if (ret < 0) { /* Abort stream? */ } return ret; } int virNetClientStreamRecvPacket(virNetClientStream *st, virNetClient *client, char *data, size_t nbytes, bool nonblock, unsigned int flags) { int rv = -1; size_t want; VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d flags=0x%x", st, client, data, nbytes, nonblock, flags); virCheckFlags(VIR_STREAM_RECV_STOP_AT_HOLE, -1); virObjectLock(st); reread: if (virNetClientStreamCheckState(st) < 0) goto cleanup; if (!st->rx && !st->incomingEOF) { virNetMessage *msg; int ret; if (nonblock) { VIR_DEBUG("Non-blocking mode and no data available"); rv = -2; goto cleanup; } if (!(msg = virNetMessageNew(false))) goto cleanup; msg->header.prog = virNetClientProgramGetProgram(st->prog); msg->header.vers = virNetClientProgramGetVersion(st->prog); msg->header.type = VIR_NET_STREAM; msg->header.serial = st->serial; msg->header.proc = st->proc; msg->header.status = VIR_NET_CONTINUE; VIR_DEBUG("Dummy packet to wait for stream data"); virObjectUnlock(st); ret = virNetClientSendStream(client, msg, st); virObjectLock(st); virNetMessageFree(msg); if (ret < 0) goto cleanup; } VIR_DEBUG("After IO rx=%p", st->rx); if (st->rx && st->rx->header.type == VIR_NET_STREAM_HOLE && st->holeLength == 0) { /* Handle skip sent to us by server. */ if (virNetClientStreamHandleHole(client, st) < 0) goto cleanup; } if (!st->rx && !st->incomingEOF && st->holeLength == 0) { if (nonblock) { VIR_DEBUG("Non-blocking mode and no data available"); rv = -2; goto cleanup; } /* We have consumed all packets from incoming queue but those * were only skip packets, no data. Read the stream again. */ goto reread; } want = nbytes; if (st->holeLength) { /* Pretend holeLength zeroes was read from stream. */ size_t len = want; /* Yes, pretend unless we are asked not to. */ if (flags & VIR_STREAM_RECV_STOP_AT_HOLE) { /* No error reporting here. Caller knows what they are doing. */ rv = -3; goto cleanup; } if (len > st->holeLength) len = st->holeLength; memset(data, 0, len); st->holeLength -= len; want -= len; } while (want && st->rx && st->rx->header.type == VIR_NET_STREAM) { virNetMessage *msg = st->rx; size_t len = want; if (len > msg->bufferLength - msg->bufferOffset) len = msg->bufferLength - msg->bufferOffset; if (!len) break; memcpy(data + (nbytes - want), msg->buffer + msg->bufferOffset, len); want -= len; msg->bufferOffset += len; if (msg->bufferOffset == msg->bufferLength) { virNetMessageQueueServe(&st->rx); virNetMessageFree(msg); } } rv = nbytes - want; virNetClientStreamEventTimerUpdate(st); cleanup: virObjectUnlock(st); return rv; } int virNetClientStreamSendHole(virNetClientStream *st, virNetClient *client, long long length, unsigned int flags) { virNetMessage *msg = NULL; virNetStreamHole data; int ret = -1; VIR_DEBUG("st=%p length=%llu", st, length); if (!st->allowSkip) { virReportError(VIR_ERR_OPERATION_INVALID, "%s", _("Skipping is not supported with this stream")); return -1; } memset(&data, 0, sizeof(data)); data.length = length; data.flags = flags; if (!(msg = virNetMessageNew(false))) return -1; virObjectLock(st); msg->header.prog = virNetClientProgramGetProgram(st->prog); msg->header.vers = virNetClientProgramGetVersion(st->prog); msg->header.status = VIR_NET_CONTINUE; msg->header.type = VIR_NET_STREAM_HOLE; msg->header.serial = st->serial; msg->header.proc = st->proc; virObjectUnlock(st); if (virNetMessageEncodeHeader(msg) < 0) goto cleanup; if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetStreamHole, &data) < 0) goto cleanup; if (virNetClientSendStream(client, msg, st) < 0) goto cleanup; ret = 0; cleanup: virNetMessageFree(msg); return ret; } int virNetClientStreamRecvHole(virNetClient *client G_GNUC_UNUSED, virNetClientStream *st, long long *length) { if (!st->allowSkip) { virReportError(VIR_ERR_OPERATION_INVALID, "%s", _("Holes are not supported with this stream")); return -1; } virObjectLock(st); if (virNetClientStreamCheckState(st) < 0) { virObjectUnlock(st); return -1; } *length = st->holeLength; st->holeLength = 0; virObjectUnlock(st); return 0; } int virNetClientStreamEventAddCallback(virNetClientStream *st, int events, virNetClientStreamEventCallback cb, void *opaque, virFreeCallback ff) { int ret = -1; virObjectLock(st); if (st->cb) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("multiple stream callbacks not supported")); goto cleanup; } virObjectRef(st); if ((st->cbTimer = virEventAddTimeout(-1, virNetClientStreamEventTimer, st, virObjectUnref)) < 0) { virObjectUnref(st); goto cleanup; } st->cb = cb; st->cbOpaque = opaque; st->cbFree = ff; st->cbEvents = events; virNetClientStreamEventTimerUpdate(st); ret = 0; cleanup: virObjectUnlock(st); return ret; } int virNetClientStreamEventUpdateCallback(virNetClientStream *st, int events) { int ret = -1; virObjectLock(st); if (!st->cb) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("no stream callback registered")); goto cleanup; } st->cbEvents = events; virNetClientStreamEventTimerUpdate(st); ret = 0; cleanup: virObjectUnlock(st); return ret; } int virNetClientStreamEventRemoveCallback(virNetClientStream *st) { int ret = -1; virObjectLock(st); if (!st->cb) { virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("no stream callback registered")); goto cleanup; } if (!st->cbDispatch && st->cbFree) (st->cbFree)(st->cbOpaque); st->cb = NULL; st->cbOpaque = NULL; st->cbFree = NULL; st->cbEvents = 0; virEventRemoveTimeout(st->cbTimer); ret = 0; cleanup: virObjectUnlock(st); return ret; } bool virNetClientStreamEOF(virNetClientStream *st) { return st->incomingEOF; } int virNetClientStreamInData(virNetClientStream *st, int *inData, long long *length) { int ret = -1; bool msgPopped = false; virNetMessage *msg = NULL; virObjectLock(st); if (!st->allowSkip) { virReportError(VIR_ERR_OPERATION_INVALID, "%s", _("Holes are not supported with this stream")); goto cleanup; } if (virNetClientStreamCheckState(st) < 0) goto cleanup; msg = st->rx; if (!msg) { /* No incoming message. This means that the stream is at its end. In * this case, virStreamInData() should set both inData and length to * zero and return success. If there is a trailing hole though (there * shouldn't be), signal that to the caller. */ *inData = 0; *length = st->holeLength; st->holeLength = 0; } else if (msg->header.type == VIR_NET_STREAM) { *inData = 1; *length = msg->bufferLength - msg->bufferOffset; } else if (msg->header.type == VIR_NET_STREAM_HOLE) { *inData = 0; if (st->holeLength == 0) { if (virNetClientStreamHandleHole(NULL, st) < 0) goto cleanup; /* virNetClientStreamHandleHole() called above did pop the message from * the queue (and freed it). Instead of trying to push it back let's * just signal to the caller what we did. */ msgPopped = true; } *length = st->holeLength; st->holeLength = 0; } else { virReportError(VIR_ERR_INTERNAL_ERROR, _("Invalid message prog=%1$d type=%2$d serial=%3$u proc=%4$d"), msg->header.prog, msg->header.type, msg->header.serial, msg->header.proc); goto cleanup; } ret = msgPopped ? 1 : 0; cleanup: virObjectUnlock(st); return ret; }