summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Carey <jcarey@argv.me>2018-02-21 12:13:58 -0500
committerJason Carey <jcarey@argv.me>2018-03-08 10:56:30 -0500
commit6cfa204de9fe5a5c0f93c0ba2e0fc8f19d307b78 (patch)
tree6719bfcb24ccdd1a63629b2e3f71c091451009fa
parent707329965e7300f409694698b64ec42dd0d85e46 (diff)
downloadmongo-6cfa204de9fe5a5c0f93c0ba2e0fc8f19d307b78.tar.gz
SERVER-33572 Add ProducerConsumerQueue to util
Add a bounded, interruptible, thread safe, single producer, multi-consumer queue to the utility directory.
-rw-r--r--src/mongo/base/error_codes.err2
-rw-r--r--src/mongo/util/SConscript11
-rw-r--r--src/mongo/util/producer_consumer_queue.h569
-rw-r--r--src/mongo/util/producer_consumer_queue_test.cpp706
4 files changed, 1288 insertions, 0 deletions
diff --git a/src/mongo/base/error_codes.err b/src/mongo/base/error_codes.err
index 8e0b848421a..08a9c3aa667 100644
--- a/src/mongo/base/error_codes.err
+++ b/src/mongo/base/error_codes.err
@@ -244,6 +244,8 @@ error_code("IncompatibleWithUpgradedServer", 243)
error_code("TransactionAborted", 244)
error_code("BrokenPromise", 245)
error_code("SnapshotUnavailable", 246)
+error_code("ProducerConsumerQueueBatchTooLarge", 247)
+error_code("ProducerConsumerQueueEndClosed", 248)
# Error codes 4000-8999 are reserved.
diff --git a/src/mongo/util/SConscript b/src/mongo/util/SConscript
index 02da7744223..f3ccb536a6c 100644
--- a/src/mongo/util/SConscript
+++ b/src/mongo/util/SConscript
@@ -598,6 +598,17 @@ env.CppUnitTest(
)
env.CppUnitTest(
+ target='producer_consumer_queue_test',
+ source=[
+ 'producer_consumer_queue_test.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/base',
+ '$BUILD_DIR/mongo/db/service_context',
+ ]
+)
+
+env.CppUnitTest(
target='duration_test',
source=[
'duration_test.cpp',
diff --git a/src/mongo/util/producer_consumer_queue.h b/src/mongo/util/producer_consumer_queue.h
new file mode 100644
index 00000000000..8f2b17d2265
--- /dev/null
+++ b/src/mongo/util/producer_consumer_queue.h
@@ -0,0 +1,569 @@
+/**
+ * Copyright (C) 2018 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects
+ * for all of the code used other than as permitted herein. If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so. If you do not
+ * wish to do so, delete this exception statement from your version. If you
+ * delete this exception statement from all source files in the program,
+ * then also delete it in the license file.
+ */
+
+#pragma once
+
+#include <boost/optional.hpp>
+#include <deque>
+#include <list>
+#include <queue>
+#include <stack>
+
+#include "mongo/db/operation_context.h"
+#include "mongo/stdx/condition_variable.h"
+#include "mongo/stdx/mutex.h"
+#include "mongo/util/concurrency/with_lock.h"
+#include "mongo/util/scopeguard.h"
+
+namespace mongo {
+
+namespace producer_consumer_queue_detail {
+
+/**
+ * The default cost function for the producer consumer queue.
+ *
+ * By default, all items in the queue have equal weight.
+ */
+struct DefaultCostFunction {
+ template <typename T>
+ size_t operator()(const T&) const {
+ return 1;
+ }
+};
+
+// Various helpers to tighten down whether the args getting passed are valid interruption args.
+//
+// Whatever the caller passes in the interruption args, they need to be invocable on one of
+// these helpers. std::is_invocable would do the job in C++17
+constexpr std::false_type areInterruptionArgsHelper(...) {
+ return {};
+}
+
+constexpr std::true_type areInterruptionArgsHelper(OperationContext*) {
+ return {};
+}
+
+constexpr std::true_type areInterruptionArgsHelper(OperationContext*, Milliseconds) {
+ return {};
+}
+
+constexpr std::true_type areInterruptionArgsHelper(OperationContext*, Date_t) {
+ return {};
+}
+
+constexpr std::true_type areInterruptionArgsHelper(Milliseconds) {
+ return {};
+}
+
+constexpr std::true_type areInterruptionArgsHelper(Date_t) {
+ return {};
+}
+
+template <typename U, typename... InterruptionArgs>
+constexpr auto areInterruptionArgs(U&& u, InterruptionArgs&&... args) {
+ return areInterruptionArgsHelper(std::forward<U>(u), std::forward<InterruptionArgs>(args)...);
+}
+
+constexpr std::true_type areInterruptionArgs() {
+ return {};
+}
+
+} // namespace producer_consumer_queue_detail
+
+/**
+ * A bounded, blocking, thread safe, cost parametrizable, single producer, multi-consumer queue.
+ *
+ * Properties:
+ * bounded - the queue can be limited in the number of items it can hold
+ * blocking - when the queue is full, or has no entries, callers block
+ * thread safe - the queue can be accessed safely from multiple threads at the same time
+ * cost parametrizable - the cost of items in the queue need not be equal. I.e. your items could
+ * be discrete byte buffers and the queue depth measured in bytes, so that
+ * the queue could hold one large buffer, or many smaller ones
+ * single producer - Only one thread may push work into the queue
+ * multi-consumer - Any number of threads may pop work out of the queue
+ *
+ * Interruptibility:
+ * All of the blocking methods on this type allow for 6 kinds of interruptibility. The matrix is
+ * parameterized by (void|OperationContext*)|(void|Milliseconds|Date_t). These provide different
+ * kinds of waiting based on whether the method should be interruptible via opCtx, and then
+ * whether they should timeout via deadline or duration
+ *
+ * A contrived example: pcq.pop(opCtx, Minutes(1)) would be warranted if:
+ * - The caller is blocking on a client thread. (opCtx)
+ * - The caller needs to act periodically on inactivity. (the duration)
+ *
+ * Exceptions include:
+ * timeouts
+ * ErrorCodes::ExceededTimeLimit exceptions
+ * opCtx interrupts
+ * ErrorCodes::Interrupted exceptions
+ * closure of queue endpoints
+ * ErrorCodes::ProducerConsumerQueueEndClosed
+ * pushes with batches that exceed the max queue size
+ * ErrorCodes::ProducerConsumerQueueBatchTooLarge
+ *
+ * Cost Function:
+ * The cost function must have a call operator which takes a const T& and returns the cost in
+ * size_t units. It must be pure across moves for a given T and never return zero. The intent of
+ * the cost function is to express the kind of bounds the queue provides, rather than to
+ * specialize behavior for a type. I.e. you should not specialize the default cost function and
+ * the cost function should always be explicit in the type.
+ */
+template <typename T, typename CostFunc = producer_consumer_queue_detail::DefaultCostFunction>
+class ProducerConsumerQueue {
+
+public:
+ // By default the queue depth is unlimited
+ ProducerConsumerQueue()
+ : ProducerConsumerQueue(std::numeric_limits<size_t>::max(), CostFunc{}) {}
+
+ // Or it can be measured in whatever units your size function returns
+ explicit ProducerConsumerQueue(size_t size) : ProducerConsumerQueue(size, CostFunc{}) {}
+
+ // If your cost function has meaningful state, you may also pass a non-default constructed
+ // instance
+ explicit ProducerConsumerQueue(size_t size, CostFunc costFunc)
+ : _max(size), _costFunc(std::move(costFunc)) {}
+
+ ProducerConsumerQueue(const ProducerConsumerQueue&) = delete;
+ ProducerConsumerQueue& operator=(const ProducerConsumerQueue&) = delete;
+
+ ProducerConsumerQueue(ProducerConsumerQueue&&) = delete;
+ ProducerConsumerQueue& operator=(ProducerConsumerQueue&&) = delete;
+
+ ~ProducerConsumerQueue() {
+ invariant(!_producerWants);
+ invariant(!_consumers);
+ }
+
+ // Pushes the passed T into the queue
+ //
+ // Leaves T unchanged if an interrupt exception is thrown while waiting for space
+ template <
+ typename... InterruptionArgs,
+ typename = std::enable_if_t<decltype(producer_consumer_queue_detail::areInterruptionArgs(
+ std::declval<InterruptionArgs>()...))::value>>
+ void push(T&& t, InterruptionArgs&&... interruptionArgs) {
+ _pushRunner([&](stdx::unique_lock<stdx::mutex>& lk) {
+ auto cost = _invokeCostFunc(t, lk);
+ uassert(ErrorCodes::ProducerConsumerQueueBatchTooLarge,
+ str::stream() << "cost of item (" << cost
+ << ") larger than maximum queue size ("
+ << _max
+ << ")",
+ cost <= _max);
+
+ _waitForSpace(lk, cost, std::forward<InterruptionArgs>(interruptionArgs)...);
+ _push(lk, std::move(t));
+ });
+ }
+
+ // Pushes all Ts into the queue
+ //
+ // Blocks until all of the Ts can be pushed at once
+ //
+ // StartIterator must be ForwardIterator
+ //
+ // Leaves the values underneath the iterators unchanged if an interrupt exception is thrown
+ // while waiting for space
+ //
+ // Lifecycle methods of T must not throw if you want to use this method, as there's no obvious
+ // mechanism to see what was and was not pushed if those do throw
+ template <
+ typename StartIterator,
+ typename EndIterator,
+ typename... InterruptionArgs,
+ typename = std::enable_if_t<decltype(producer_consumer_queue_detail::areInterruptionArgs(
+ std::declval<InterruptionArgs>()...))::value>>
+ void pushMany(StartIterator start, EndIterator last, InterruptionArgs&&... interruptionArgs) {
+ return _pushRunner([&](stdx::unique_lock<stdx::mutex>& lk) {
+ size_t cost = 0;
+ for (auto iter = start; iter != last; ++iter) {
+ cost += _invokeCostFunc(*iter, lk);
+ }
+
+ uassert(ErrorCodes::ProducerConsumerQueueBatchTooLarge,
+ str::stream() << "cost of items in batch (" << cost
+ << ") larger than maximum queue size ("
+ << _max
+ << ")",
+ cost <= _max);
+
+ _waitForSpace(lk, cost, std::forward<InterruptionArgs>(interruptionArgs)...);
+
+ for (auto iter = start; iter != last; ++iter) {
+ _push(lk, std::move(*iter));
+ }
+ });
+ }
+
+ // Attempts a non-blocking push of a value
+ //
+ // Leaves T unchanged if it fails
+ bool tryPush(T&& t) {
+ return _pushRunner(
+ [&](stdx::unique_lock<stdx::mutex>& lk) { return _tryPush(lk, std::move(t)); });
+ }
+
+ // Pops one T out of the queue
+ template <
+ typename... InterruptionArgs,
+ typename = std::enable_if_t<decltype(producer_consumer_queue_detail::areInterruptionArgs(
+ std::declval<InterruptionArgs>()...))::value>>
+ T pop(InterruptionArgs&&... interruptionArgs) {
+ return _popRunner([&](stdx::unique_lock<stdx::mutex>& lk) {
+ _waitForNonEmpty(lk, std::forward<InterruptionArgs>(interruptionArgs)...);
+ return _pop(lk);
+ });
+ }
+
+ // Waits for at least one item in the queue, then pops items out of the queue until it would
+ // block
+ //
+ // OutputIterator must not throw on move assignment to *iter or popped values may be lost
+ // TODO: add sfinae to check to enforce
+ //
+ // Returns the cost value of the items extracted, along with the updated output iterator
+ template <
+ typename OutputIterator,
+ typename... InterruptionArgs,
+ typename = std::enable_if_t<decltype(producer_consumer_queue_detail::areInterruptionArgs(
+ std::declval<InterruptionArgs>()...))::value>>
+ std::pair<size_t, OutputIterator> popMany(OutputIterator iterator,
+ InterruptionArgs&&... interruptionArgs) {
+ return popManyUpTo(_max, iterator, std::forward<InterruptionArgs>(interruptionArgs)...);
+ }
+
+ // Waits for at least one item in the queue, then pops items out of the queue until it would
+ // block, or we've exceeded our budget
+ //
+ // OutputIterator must not throw on move assignment to *iter or popped values may be lost
+ // TODO: add sfinae to check to enforce
+ //
+ // Returns the cost value of the items extracted, along with the updated output iterator
+ template <
+ typename OutputIterator,
+ typename... InterruptionArgs,
+ typename = std::enable_if_t<decltype(producer_consumer_queue_detail::areInterruptionArgs(
+ std::declval<InterruptionArgs>()...))::value>>
+ std::pair<size_t, OutputIterator> popManyUpTo(size_t budget,
+ OutputIterator iterator,
+ InterruptionArgs&&... interruptionArgs) {
+ return _popRunner([&](stdx::unique_lock<stdx::mutex>& lk) {
+ size_t cost = 0;
+
+ _waitForNonEmpty(lk, std::forward<InterruptionArgs>(interruptionArgs)...);
+
+ while (auto out = _tryPop(lk)) {
+ cost += _invokeCostFunc(*out, lk);
+ *iterator = std::move(*out);
+ ++iterator;
+
+ if (cost >= budget) {
+ break;
+ }
+ }
+
+ return std::make_pair(cost, iterator);
+ });
+ }
+
+ // Attempts a non-blocking pop of a value
+ boost::optional<T> tryPop() {
+ return _popRunner([&](stdx::unique_lock<stdx::mutex>& lk) { return _tryPop(lk); });
+ }
+
+ // Closes the producer end. Consumers will continue to consume until the queue is exhausted, at
+ // which time they will begin to throw with an interruption dbexception
+ void closeProducerEnd() {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+
+ _producerEndClosed = true;
+
+ _notifyIfNecessary(lk);
+ }
+
+ // Closes the consumer end. This causes all callers to throw with an interruption dbexception
+ void closeConsumerEnd() {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+
+ _consumerEndClosed = true;
+ _producerEndClosed = true;
+
+ _notifyIfNecessary(lk);
+ }
+
+ // TEST ONLY FUNCTIONS
+
+ // Returns the current depth of the queue in CostFunction units
+ size_t sizeForTest() const {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+
+ return _current;
+ }
+
+ // Returns true if the queue is empty
+ bool emptyForTest() const {
+ return sizeForTest() == 0;
+ }
+
+private:
+ size_t _invokeCostFunc(const T& t, WithLock) {
+ auto cost = _costFunc(t);
+ invariant(cost);
+ return cost;
+ }
+
+ void _checkProducerClosed(WithLock) {
+ uassert(
+ ErrorCodes::ProducerConsumerQueueEndClosed, "Producer end closed", !_producerEndClosed);
+ uassert(
+ ErrorCodes::ProducerConsumerQueueEndClosed, "Consumer end closed", !_consumerEndClosed);
+ }
+
+ void _checkConsumerClosed(WithLock) {
+ uassert(
+ ErrorCodes::ProducerConsumerQueueEndClosed, "Consumer end closed", !_consumerEndClosed);
+ uassert(ErrorCodes::ProducerConsumerQueueEndClosed,
+ "Producer end closed and values exhausted",
+ !(_producerEndClosed && _queue.empty()));
+ }
+
+ void _notifyIfNecessary(WithLock) {
+ // If we've closed the consumer end, or if the production end is closed and we've exhausted
+ // the queue, wake everyone up and get out of here
+ if (_consumerEndClosed || (_queue.empty() && _producerEndClosed)) {
+ if (_consumers) {
+ _condvarConsumer.notify_all();
+ }
+
+ if (_producerWants) {
+ _condvarProducer.notify_one();
+ }
+
+ return;
+ }
+
+ // If a producer is queued, and we have enough space for it to push its work
+ if (_producerWants && _current + _producerWants <= _max) {
+ _condvarProducer.notify_one();
+
+ return;
+ }
+
+ // If we have consumers and anything in the queue, notify consumers
+ if (_consumers && _queue.size()) {
+ _condvarConsumer.notify_one();
+
+ return;
+ }
+ }
+
+ template <typename Callback>
+ auto _pushRunner(Callback&& cb) {
+ stdx::unique_lock<stdx::mutex> lk(_mutex);
+
+ _checkProducerClosed(lk);
+
+ const auto guard = MakeGuard([&] { _notifyIfNecessary(lk); });
+
+ return cb(lk);
+ }
+
+ template <typename Callback>
+ auto _popRunner(Callback&& cb) {
+ stdx::unique_lock<stdx::mutex> lk(_mutex);
+
+ _checkConsumerClosed(lk);
+
+ const auto guard = MakeGuard([&] { _notifyIfNecessary(lk); });
+
+ return cb(lk);
+ }
+
+ bool _tryPush(WithLock wl, T&& t) {
+ size_t cost = _invokeCostFunc(t, wl);
+ if (_current + cost <= _max) {
+ _queue.emplace(std::move(t));
+ _current += cost;
+ return true;
+ }
+
+ return false;
+ }
+
+ void _push(WithLock wl, T&& t) {
+ size_t cost = _invokeCostFunc(t, wl);
+ invariant(_current + cost <= _max);
+
+ _queue.emplace(std::move(t));
+ _current += cost;
+ }
+
+ boost::optional<T> _tryPop(WithLock wl) {
+ boost::optional<T> out;
+
+ if (!_queue.empty()) {
+ out.emplace(std::move(_queue.front()));
+ _queue.pop();
+ _current -= _invokeCostFunc(*out, wl);
+ }
+
+ return out;
+ }
+
+ T _pop(WithLock wl) {
+ invariant(_queue.size());
+
+ auto t = std::move(_queue.front());
+ _queue.pop();
+
+ _current -= _invokeCostFunc(t, wl);
+
+ return t;
+ }
+
+ template <typename... InterruptionArgs>
+ void _waitForSpace(stdx::unique_lock<stdx::mutex>& lk,
+ size_t cost,
+ InterruptionArgs&&... interruptionArgs) {
+ invariant(!_producerWants);
+
+ _producerWants = cost;
+ const auto guard = MakeGuard([&] { _producerWants = 0; });
+
+ _waitFor(lk,
+ _condvarProducer,
+ [&] {
+ _checkProducerClosed(lk);
+ return _current + cost <= _max;
+ },
+ std::forward<InterruptionArgs>(interruptionArgs)...);
+ }
+
+ template <typename... InterruptionArgs>
+ void _waitForNonEmpty(stdx::unique_lock<stdx::mutex>& lk,
+ InterruptionArgs&&... interruptionArgs) {
+
+ _consumers++;
+ const auto guard = MakeGuard([&] { _consumers--; });
+
+ _waitFor(lk,
+ _condvarConsumer,
+ [&] {
+ _checkConsumerClosed(lk);
+ return _queue.size();
+ },
+ std::forward<InterruptionArgs>(interruptionArgs)...);
+ }
+
+ template <typename Callback>
+ void _waitFor(stdx::unique_lock<stdx::mutex>& lk,
+ stdx::condition_variable& condvar,
+ Callback&& pred,
+ OperationContext* opCtx) {
+ opCtx->waitForConditionOrInterrupt(condvar, lk, pred);
+ }
+
+ template <typename Callback>
+ void _waitFor(stdx::unique_lock<stdx::mutex>& lk,
+ stdx::condition_variable& condvar,
+ Callback&& pred) {
+ condvar.wait(lk, pred);
+ }
+
+ template <typename Callback>
+ void _waitFor(stdx::unique_lock<stdx::mutex>& lk,
+ stdx::condition_variable& condvar,
+ Callback&& pred,
+ OperationContext* opCtx,
+ Date_t deadline) {
+ uassert(ErrorCodes::ExceededTimeLimit,
+ "exceeded timeout",
+ opCtx->waitForConditionOrInterruptUntil(condvar, lk, deadline, pred));
+ }
+
+ template <typename Callback>
+ void _waitFor(stdx::unique_lock<stdx::mutex>& lk,
+ stdx::condition_variable& condvar,
+ Callback&& pred,
+ Date_t deadline) {
+ uassert(ErrorCodes::ExceededTimeLimit,
+ "exceeded timeout",
+ condvar.wait_until(lk, deadline.toSystemTimePoint(), pred));
+ }
+
+ template <typename Callback>
+ void _waitFor(stdx::unique_lock<stdx::mutex>& lk,
+ stdx::condition_variable& condvar,
+ Callback&& pred,
+ OperationContext* opCtx,
+ Milliseconds duration) {
+ uassert(ErrorCodes::ExceededTimeLimit,
+ "exceeded timeout",
+ opCtx->waitForConditionOrInterruptFor(condvar, lk, duration, pred));
+ }
+
+ template <typename Callback>
+ void _waitFor(stdx::unique_lock<stdx::mutex>& lk,
+ stdx::condition_variable& condvar,
+ Callback&& pred,
+ Milliseconds duration) {
+ uassert(ErrorCodes::ExceededTimeLimit,
+ "exceeded timeout",
+ condvar.wait_for(lk, duration.toSystemDuration(), pred));
+ }
+
+ mutable stdx::mutex _mutex;
+ stdx::condition_variable _condvarConsumer;
+ stdx::condition_variable _condvarProducer;
+
+ // Max size of the queue
+ const size_t _max;
+
+ // User's cost function
+ CostFunc _costFunc;
+
+ // Current size of the queue
+ size_t _current = 0;
+
+ std::queue<T> _queue;
+
+ // Counter for consumers in the queue
+ size_t _consumers = 0;
+
+ // Size of batch the blocking producer wants to insert
+ size_t _producerWants = 0;
+
+ // Flags that we're shutting down the queue
+ bool _consumerEndClosed = false;
+ bool _producerEndClosed = false;
+};
+
+} // namespace mongo
diff --git a/src/mongo/util/producer_consumer_queue_test.cpp b/src/mongo/util/producer_consumer_queue_test.cpp
new file mode 100644
index 00000000000..22577e791df
--- /dev/null
+++ b/src/mongo/util/producer_consumer_queue_test.cpp
@@ -0,0 +1,706 @@
+/**
+ * Copyright (C) 2018 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects
+ * for all of the code used other than as permitted herein. If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so. If you do not
+ * wish to do so, delete this exception statement from your version. If you
+ * delete this exception statement from all source files in the program,
+ * then also delete it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/unittest/unittest.h"
+
+#include "mongo/util/producer_consumer_queue.h"
+
+#include "mongo/db/service_context_noop.h"
+#include "mongo/stdx/condition_variable.h"
+#include "mongo/stdx/mutex.h"
+#include "mongo/stdx/thread.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo {
+
+namespace {
+
+template <typename... Args>
+class ProducerConsumerQueueTestHelper;
+
+template <>
+class ProducerConsumerQueueTestHelper<OperationContext> {
+public:
+ ProducerConsumerQueueTestHelper(ServiceContext* serviceCtx) : _serviceCtx(serviceCtx) {}
+
+ template <typename Callback>
+ stdx::thread runThread(StringData name, Callback&& cb) {
+ return stdx::thread([this, name, cb] {
+ auto client = _serviceCtx->makeClient(name.toString());
+ auto opCtx = client->makeOperationContext();
+
+ cb(opCtx.get());
+ });
+ }
+
+private:
+ ServiceContext* _serviceCtx;
+};
+
+template <typename Timeout>
+class ProducerConsumerQueueTestHelper<OperationContext, Timeout> {
+public:
+ ProducerConsumerQueueTestHelper(ServiceContext* serviceCtx, Timeout timeout)
+ : _serviceCtx(serviceCtx), _timeout(timeout) {}
+
+ template <typename Callback>
+ stdx::thread runThread(StringData name, Callback&& cb) {
+ return stdx::thread([this, name, cb] {
+ auto client = _serviceCtx->makeClient(name.toString());
+ auto opCtx = client->makeOperationContext();
+
+ cb(opCtx.get(), _timeout);
+ });
+ }
+
+private:
+ ServiceContext* _serviceCtx;
+ Timeout _timeout;
+};
+
+template <>
+class ProducerConsumerQueueTestHelper<> {
+public:
+ ProducerConsumerQueueTestHelper() = default;
+
+ template <typename Callback>
+ stdx::thread runThread(StringData name, Callback&& cb) {
+ return stdx::thread([this, name, cb] { cb(); });
+ }
+};
+
+template <typename Timeout>
+class ProducerConsumerQueueTestHelper<Timeout> {
+public:
+ ProducerConsumerQueueTestHelper(Timeout timeout) : _timeout(timeout) {}
+
+ template <typename Callback>
+ stdx::thread runThread(StringData name, Callback&& cb) {
+ return stdx::thread([this, name, cb] { cb(_timeout); });
+ }
+
+private:
+ Timeout _timeout;
+};
+
+class ProducerConsumerQueueTest : public unittest::Test {
+public:
+ ProducerConsumerQueueTest() : _serviceCtx(stdx::make_unique<ServiceContextNoop>()) {}
+
+ template <typename Callback>
+ stdx::thread runThread(StringData name, Callback&& cb) {
+ return stdx::thread([this, name, cb] {
+ auto client = _serviceCtx->makeClient(name.toString());
+ auto opCtx = client->makeOperationContext();
+
+ cb(opCtx.get());
+ });
+ }
+
+ template <typename Callback>
+ void runPermutations(Callback&& callback) {
+ const Minutes duration(30);
+
+ callback(ProducerConsumerQueueTestHelper<OperationContext>(_serviceCtx.get()));
+ callback(ProducerConsumerQueueTestHelper<OperationContext, Milliseconds>(_serviceCtx.get(),
+ duration));
+ callback(ProducerConsumerQueueTestHelper<OperationContext, Date_t>(
+ _serviceCtx.get(), _serviceCtx->getPreciseClockSource()->now() + duration));
+ callback(ProducerConsumerQueueTestHelper<>());
+ callback(ProducerConsumerQueueTestHelper<Milliseconds>(duration));
+ callback(ProducerConsumerQueueTestHelper<Date_t>(
+ _serviceCtx->getPreciseClockSource()->now() + duration));
+ }
+
+ template <typename Callback>
+ void runTimeoutPermutations(Callback&& callback) {
+ const Milliseconds duration(10);
+
+ callback(ProducerConsumerQueueTestHelper<OperationContext, Milliseconds>(_serviceCtx.get(),
+ duration));
+ callback(ProducerConsumerQueueTestHelper<OperationContext, Date_t>(
+ _serviceCtx.get(), _serviceCtx->getPreciseClockSource()->now() + duration));
+ callback(ProducerConsumerQueueTestHelper<Milliseconds>(duration));
+ callback(ProducerConsumerQueueTestHelper<Date_t>(
+ _serviceCtx->getPreciseClockSource()->now() + duration));
+ }
+
+private:
+ std::unique_ptr<ServiceContext> _serviceCtx;
+};
+
+class MoveOnly {
+public:
+ struct CostFunc {
+ CostFunc() = default;
+ explicit CostFunc(size_t val) : val(val) {}
+
+ size_t operator()(const MoveOnly& mo) const {
+ return val + *mo._val;
+ }
+
+ const size_t val = 0;
+ };
+
+ explicit MoveOnly(int i) : _val(i) {}
+
+ MoveOnly(const MoveOnly&) = delete;
+ MoveOnly& operator=(const MoveOnly&) = delete;
+
+ MoveOnly(MoveOnly&& other) : _val(other._val) {
+ other._val.reset();
+ }
+
+ MoveOnly& operator=(MoveOnly&& other) {
+ if (&other == this) {
+ return *this;
+ }
+
+ _val = other._val;
+ other._val.reset();
+
+ return *this;
+ }
+
+ bool movedFrom() const {
+ return !_val;
+ }
+
+ friend bool operator==(const MoveOnly& lhs, const MoveOnly& rhs) {
+ return *lhs._val == *rhs._val;
+ }
+
+ friend bool operator!=(const MoveOnly& lhs, const MoveOnly& rhs) {
+ return !(lhs == rhs);
+ }
+
+ friend std::ostream& operator<<(std::ostream& os, const MoveOnly& mo) {
+ return (os << "MoveOnly(" << *mo._val << ")");
+ }
+
+private:
+ boost::optional<int> _val;
+};
+
+TEST_F(ProducerConsumerQueueTest, basicPushPop) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ helper
+ .runThread(
+ "Producer",
+ [&](auto... interruptionArgs) { pcq.push(MoveOnly(1), interruptionArgs...); })
+ .join();
+
+ ASSERT_EQUALS(pcq.sizeForTest(), 1ul);
+
+ helper
+ .runThread("Consumer",
+ [&](auto... interruptionArgs) {
+ ASSERT_EQUALS(pcq.pop(interruptionArgs...), MoveOnly(1));
+ })
+ .join();
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, closeConsumerEnd) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{1};
+
+ pcq.push(MoveOnly(1));
+
+ auto producer = helper.runThread("Producer", [&](auto... interruptionArgs) {
+ ASSERT_THROWS_CODE(pcq.push(MoveOnly(2), interruptionArgs...),
+ DBException,
+ ErrorCodes::ProducerConsumerQueueEndClosed);
+ });
+
+ ASSERT_EQUALS(pcq.sizeForTest(), 1ul);
+
+ pcq.closeConsumerEnd();
+
+ ASSERT_THROWS_CODE(pcq.pop(), DBException, ErrorCodes::ProducerConsumerQueueEndClosed);
+
+ producer.join();
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, closeProducerEndImmediate) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ pcq.push(MoveOnly(1));
+ pcq.closeProducerEnd();
+
+ helper
+ .runThread("Consumer",
+ [&](auto... interruptionArgs) {
+ ASSERT_EQUALS(pcq.pop(interruptionArgs...), MoveOnly(1));
+
+ ASSERT_THROWS_CODE(pcq.pop(interruptionArgs...),
+ DBException,
+ ErrorCodes::ProducerConsumerQueueEndClosed);
+ })
+ .join();
+
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, closeProducerEndBlocking) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ auto consumer = helper.runThread("Consumer", [&](auto... interruptionArgs) {
+ ASSERT_THROWS_CODE(pcq.pop(interruptionArgs...),
+ DBException,
+ ErrorCodes::ProducerConsumerQueueEndClosed);
+ });
+
+ pcq.closeProducerEnd();
+
+ consumer.join();
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, popsWithTimeout) {
+ runTimeoutPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ helper
+ .runThread(
+ "Consumer",
+ [&](auto... interruptionArgs) {
+ ASSERT_THROWS_CODE(
+ pcq.pop(interruptionArgs...), DBException, ErrorCodes::ExceededTimeLimit);
+
+ std::vector<MoveOnly> vec;
+ ASSERT_THROWS_CODE(pcq.popMany(std::back_inserter(vec), interruptionArgs...),
+ DBException,
+ ErrorCodes::ExceededTimeLimit);
+
+ ASSERT_THROWS_CODE(
+ pcq.popManyUpTo(1000, std::back_inserter(vec), interruptionArgs...),
+ DBException,
+ ErrorCodes::ExceededTimeLimit);
+ })
+ .join();
+
+ ASSERT_EQUALS(pcq.sizeForTest(), 0ul);
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, pushesWithTimeout) {
+ runTimeoutPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{1};
+
+ {
+ MoveOnly mo(1);
+ pcq.push(std::move(mo));
+ ASSERT(mo.movedFrom());
+ }
+
+ helper
+ .runThread("Consumer",
+ [&](auto... interruptionArgs) {
+ {
+ MoveOnly mo(2);
+ ASSERT_THROWS_CODE(pcq.push(std::move(mo), interruptionArgs...),
+ DBException,
+ ErrorCodes::ExceededTimeLimit);
+ ASSERT_EQUALS(pcq.sizeForTest(), 1ul);
+ ASSERT(!mo.movedFrom());
+ ASSERT_EQUALS(mo, MoveOnly(2));
+ }
+
+ {
+ std::vector<MoveOnly> vec;
+ vec.emplace_back(MoveOnly(2));
+
+ auto iter = begin(vec);
+ ASSERT_THROWS_CODE(pcq.pushMany(iter, end(vec), interruptionArgs...),
+ DBException,
+ ErrorCodes::ExceededTimeLimit);
+ ASSERT_EQUALS(pcq.sizeForTest(), 1ul);
+ ASSERT(!vec[0].movedFrom());
+ ASSERT_EQUALS(vec[0], MoveOnly(2));
+ }
+ })
+ .join();
+
+ ASSERT_EQUALS(pcq.sizeForTest(), 1ul);
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, basicPushPopWithBlocking) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ auto consumer = helper.runThread("Consumer", [&](auto... interruptionArgs) {
+ ASSERT_EQUALS(pcq.pop(interruptionArgs...), MoveOnly(1));
+ });
+
+ auto producer = helper.runThread("Producer", [&](auto... interruptionArgs) {
+ pcq.push(MoveOnly(1), interruptionArgs...);
+ });
+
+ consumer.join();
+ producer.join();
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, multipleStepPushPopWithBlocking) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{1};
+
+ auto consumer = helper.runThread("Consumer", [&](auto... interruptionArgs) {
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_EQUALS(pcq.pop(interruptionArgs...), MoveOnly(i));
+ }
+ });
+
+ auto producer = helper.runThread("Producer", [&](auto... interruptionArgs) {
+ for (int i = 0; i < 10; ++i) {
+ pcq.push(MoveOnly(i), interruptionArgs...);
+ }
+ });
+
+ consumer.join();
+ producer.join();
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+
+TEST_F(ProducerConsumerQueueTest, pushTooLarge) {
+ runPermutations([](auto helper) {
+ {
+ ProducerConsumerQueue<MoveOnly, MoveOnly::CostFunc> pcq{1};
+
+ helper
+ .runThread("Producer",
+ [&](auto... interruptionArgs) {
+ ASSERT_THROWS_CODE(pcq.push(MoveOnly(2), interruptionArgs...),
+ DBException,
+ ErrorCodes::ProducerConsumerQueueBatchTooLarge);
+ })
+ .join();
+ }
+
+ {
+ ProducerConsumerQueue<MoveOnly, MoveOnly::CostFunc> pcq{4};
+
+ std::vector<MoveOnly> vec;
+ vec.push_back(MoveOnly(3));
+ vec.push_back(MoveOnly(3));
+
+ helper
+ .runThread("Producer",
+ [&](auto... interruptionArgs) {
+ ASSERT_THROWS_CODE(
+ pcq.pushMany(begin(vec), end(vec), interruptionArgs...),
+ DBException,
+ ErrorCodes::ProducerConsumerQueueBatchTooLarge);
+ })
+ .join();
+ }
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, pushManyPopWithoutBlocking) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ helper
+ .runThread("Producer",
+ [&](auto... interruptionArgs) {
+ std::vector<MoveOnly> vec;
+ for (int i = 0; i < 10; ++i) {
+ vec.emplace_back(MoveOnly(i));
+ }
+
+ pcq.pushMany(begin(vec), end(vec), interruptionArgs...);
+ })
+ .join();
+
+ helper
+ .runThread("Consumer",
+ [&](auto... interruptionArgs) {
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_EQUALS(pcq.pop(interruptionArgs...), MoveOnly(i));
+ }
+ })
+ .join();
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, popManyPopWithBlocking) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{2};
+
+ auto consumer = helper.runThread("Consumer", [&](auto... interruptionArgs) {
+ for (int i = 0; i < 10; i = i + 2) {
+ std::vector<MoveOnly> out;
+
+ pcq.popMany(std::back_inserter(out), interruptionArgs...);
+
+ ASSERT_EQUALS(out.size(), 2ul);
+ ASSERT_EQUALS(out[0], MoveOnly(i));
+ ASSERT_EQUALS(out[1], MoveOnly(i + 1));
+ }
+ });
+
+ auto producer = helper.runThread("Producer", [&](auto... interruptionArgs) {
+ std::vector<MoveOnly> vec;
+ for (int i = 0; i < 10; ++i) {
+ vec.emplace_back(MoveOnly(i));
+ }
+
+ for (auto iter = begin(vec); iter != end(vec); iter += 2) {
+ pcq.pushMany(iter, iter + 2);
+ }
+ });
+
+ consumer.join();
+ producer.join();
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, popManyUpToPopWithBlocking) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{4};
+
+ auto consumer = helper.runThread("Consumer", [&](auto... interruptionArgs) {
+ for (int i = 0; i < 10; i = i + 2) {
+ std::vector<MoveOnly> out;
+
+ size_t spent;
+ std::tie(spent, std::ignore) =
+ pcq.popManyUpTo(2, std::back_inserter(out), interruptionArgs...);
+
+ ASSERT_EQUALS(spent, 2ul);
+ ASSERT_EQUALS(out.size(), 2ul);
+ ASSERT_EQUALS(out[0], MoveOnly(i));
+ ASSERT_EQUALS(out[1], MoveOnly(i + 1));
+ }
+ });
+
+ auto producer = helper.runThread("Producer", [&](auto... interruptionArgs) {
+ std::vector<MoveOnly> vec;
+ for (int i = 0; i < 10; ++i) {
+ vec.emplace_back(MoveOnly(i));
+ }
+
+ for (auto iter = begin(vec); iter != end(vec); iter += 2) {
+ pcq.pushMany(iter, iter + 2);
+ }
+ });
+
+ consumer.join();
+ producer.join();
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, popManyUpToPopWithBlockingWithSpecialCost) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly, MoveOnly::CostFunc> pcq{};
+
+ auto consumer = helper.runThread("Consumer", [&](auto... interruptionArgs) {
+ {
+ std::vector<MoveOnly> out;
+ size_t spent;
+ std::tie(spent, std::ignore) =
+ pcq.popManyUpTo(5, std::back_inserter(out), interruptionArgs...);
+
+ ASSERT_EQUALS(spent, 6ul);
+ ASSERT_EQUALS(out.size(), 3ul);
+ ASSERT_EQUALS(out[0], MoveOnly(1));
+ ASSERT_EQUALS(out[1], MoveOnly(2));
+ ASSERT_EQUALS(out[2], MoveOnly(3));
+ }
+
+ {
+ std::vector<MoveOnly> out;
+ size_t spent;
+ std::tie(spent, std::ignore) =
+ pcq.popManyUpTo(15, std::back_inserter(out), interruptionArgs...);
+
+ ASSERT_EQUALS(spent, 9ul);
+ ASSERT_EQUALS(out.size(), 2ul);
+ ASSERT_EQUALS(out[0], MoveOnly(4));
+ ASSERT_EQUALS(out[1], MoveOnly(5));
+ }
+ });
+
+ auto producer = helper.runThread("Producer", [&](auto... interruptionArgs) {
+ std::vector<MoveOnly> vec;
+ for (int i = 1; i < 6; ++i) {
+ vec.emplace_back(MoveOnly(i));
+ }
+
+ pcq.pushMany(begin(vec), end(vec), interruptionArgs...);
+ });
+
+ consumer.join();
+ producer.join();
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, singleProducerMultiConsumer) {
+ runPermutations([](auto helper) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ stdx::mutex mutex;
+ size_t success = 0;
+ size_t failure = 0;
+
+ std::array<stdx::thread, 3> threads;
+ for (auto& thread : threads) {
+ thread = helper.runThread("Consumer", [&](auto... interruptionArgs) {
+ {
+ try {
+ pcq.pop(interruptionArgs...);
+ stdx::lock_guard<stdx::mutex> lk(mutex);
+ success++;
+ } catch (const ExceptionFor<ErrorCodes::ProducerConsumerQueueEndClosed>&) {
+ stdx::lock_guard<stdx::mutex> lk(mutex);
+ failure++;
+ }
+ }
+ });
+ }
+
+ pcq.push(MoveOnly(1));
+ pcq.push(MoveOnly(2));
+
+ pcq.closeProducerEnd();
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ ASSERT_EQUALS(success, 2ul);
+ ASSERT_EQUALS(failure, 1ul);
+
+ ASSERT_TRUE(pcq.emptyForTest());
+ });
+}
+
+TEST_F(ProducerConsumerQueueTest, basicTryPop) {
+ ProducerConsumerQueue<MoveOnly> pcq{};
+
+ ASSERT_FALSE(pcq.tryPop());
+ ASSERT_TRUE(pcq.tryPush(MoveOnly(1)));
+ ASSERT_EQUALS(pcq.sizeForTest(), 1ul);
+
+ auto val = pcq.tryPop();
+
+ ASSERT_FALSE(pcq.tryPop());
+ ASSERT_TRUE(val);
+ ASSERT_EQUALS(*val, MoveOnly(1));
+
+ ASSERT_TRUE(pcq.emptyForTest());
+}
+
+TEST_F(ProducerConsumerQueueTest, basicTryPush) {
+ ProducerConsumerQueue<MoveOnly> pcq{1};
+
+ ASSERT_TRUE(pcq.tryPush(MoveOnly(1)));
+ ASSERT_FALSE(pcq.tryPush(MoveOnly(2)));
+
+ ASSERT_EQUALS(pcq.sizeForTest(), 1ul);
+
+ auto val = pcq.tryPop();
+ ASSERT_FALSE(pcq.tryPop());
+ ASSERT_TRUE(val);
+ ASSERT_EQUALS(*val, MoveOnly(1));
+
+ ASSERT_TRUE(pcq.emptyForTest());
+}
+
+TEST_F(ProducerConsumerQueueTest, tryPushWithSpecialCost) {
+ ProducerConsumerQueue<MoveOnly, MoveOnly::CostFunc> pcq{5};
+
+ ASSERT_TRUE(pcq.tryPush(MoveOnly(1)));
+ ASSERT_TRUE(pcq.tryPush(MoveOnly(2)));
+ ASSERT_FALSE(pcq.tryPush(MoveOnly(3)));
+
+ ASSERT_EQUALS(pcq.sizeForTest(), 3ul);
+
+ auto val1 = pcq.tryPop();
+ ASSERT_EQUALS(pcq.sizeForTest(), 2ul);
+ auto val2 = pcq.tryPop();
+ ASSERT_EQUALS(pcq.sizeForTest(), 0ul);
+ ASSERT_FALSE(pcq.tryPop());
+ ASSERT_TRUE(val1);
+ ASSERT_TRUE(val2);
+ ASSERT_EQUALS(*val1, MoveOnly(1));
+ ASSERT_EQUALS(*val2, MoveOnly(2));
+
+ ASSERT_TRUE(pcq.emptyForTest());
+}
+
+TEST_F(ProducerConsumerQueueTest, tryPushWithSpecialStatefulCost) {
+ ProducerConsumerQueue<MoveOnly, MoveOnly::CostFunc> pcq{5, MoveOnly::CostFunc(1)};
+
+ ASSERT_TRUE(pcq.tryPush(MoveOnly(1)));
+ ASSERT_TRUE(pcq.tryPush(MoveOnly(2)));
+ ASSERT_FALSE(pcq.tryPush(MoveOnly(3)));
+
+ ASSERT_EQUALS(pcq.sizeForTest(), 5ul);
+
+ auto val1 = pcq.tryPop();
+ ASSERT_EQUALS(pcq.sizeForTest(), 3ul);
+ auto val2 = pcq.tryPop();
+ ASSERT_EQUALS(pcq.sizeForTest(), 0ul);
+ ASSERT_FALSE(pcq.tryPop());
+ ASSERT_TRUE(val1);
+ ASSERT_TRUE(val2);
+ ASSERT_EQUALS(*val1, MoveOnly(1));
+ ASSERT_EQUALS(*val2, MoveOnly(2));
+
+ ASSERT_TRUE(pcq.emptyForTest());
+}
+
+} // namespace
+
+} // namespace mongo