-
Notifications
You must be signed in to change notification settings - Fork 0
/
ThreadPool.h
163 lines (134 loc) · 5.62 KB
/
ThreadPool.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#ifndef __THREADPOOL_H_
#define __THREADPOOL_H_
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>
/// A resizable pool of worker threads that execute queued tasks in FIFO order.
///
/// Thread-safety: enqueue() and clear() may be called from any thread. resize() and the
/// destructor modify the worker list, so they must not run concurrently with each other
/// (or with themselves), and must not be called from inside a task running on this pool
/// (a worker would end up joining itself).
class ThreadPool {
private:
    std::vector<std::thread> workers;          // Owned threads; workers[i] runs worker_loop(i)
    std::queue<std::function<void()>> tasks;   // Pending tasks, guarded by queue_mutex
    std::mutex queue_mutex;                    // Guards tasks, stop and active_workers
    std::condition_variable condition;         // Signalled on new work, shutdown, or resize
    bool stop;                                 // Set by the destructor; workers drain the queue, then exit
    size_t active_workers;                     // Workers whose index is >= this value retire (see resize)

    // Body executed by the worker thread with the given index.
    void worker_loop(const size_t index);

public:
    ThreadPool(const size_t threads) noexcept; // Constructor
    ~ThreadPool() noexcept; // Destructor: runs every queued task, then joins all workers
    void resize(const size_t threads) noexcept; // Resizes the thread pool; shrinking blocks until retired workers finish their current task
    void clear(void) noexcept; // Clears the queued tasks that have NOT been started yet (their futures report broken_promise)
    // Delete these functions as this class is NOT safe to be copied
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    /// Queues f(args...) for execution on a worker thread.
    /// @return a std::future holding the call's result, or the exception it threw.
    /// @throws std::runtime_error if the pool currently has no workers or is shutting down.
    ///         Both checks happen before f/args are consumed.
    template <typename F, typename... Args>
    auto enqueue(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
        using return_type = decltype(f(args...));
        std::future<return_type> res;
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            // Checked under the lock because resize() may change active_workers concurrently.
            if (active_workers == 0)
                throw std::runtime_error("enqueue called on empty ThreadPool");
            if (stop)
                throw std::runtime_error("enqueue called on stopped ThreadPool");
            auto task = std::make_shared<std::packaged_task<return_type()>>(
                std::bind(std::forward<F>(f), std::forward<Args>(args)...));
            res = task->get_future();
            // std::function needs a copyable callable, so the packaged_task is held by shared_ptr.
            tasks.emplace([task]() { (*task)(); });
        }
        condition.notify_one();
        return res;
    }
}; // class ThreadPool

inline void ThreadPool::worker_loop(const size_t index) {
    while (true) {
        std::function<void()> task;
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            // Wake on shutdown, on this worker being retired by resize(), or when work is available.
            condition.wait(lock, [this, index] {
                return stop || index >= active_workers || !tasks.empty();
            });
            if (index >= active_workers) return; // Retired by resize(); the remaining workers keep serving the queue
            if (stop && tasks.empty()) return;   // Shutdown: exit only once the queue is drained
            // Pop the next task from the queue.
            task = std::move(tasks.front());
            tasks.pop();
        } // Release the lock before running the task so other workers can proceed
        try {
            task();
        } catch (const std::exception& e) {
            // Defensive: tasks queued by enqueue() report their exceptions through their std::future.
            std::cerr << "Caught exception in thread pool task: " << e.what() << std::endl;
            // TODO: Add a way to notify the main thread (or any passed callback function) when exceptions in the threads happen
        } catch (...) {
            // In case a non-std exception is thrown
            std::cerr << "Caught non-standard exception in thread pool task" << std::endl;
        }
    }
}

inline void ThreadPool::clear(void) noexcept {
    std::queue<std::function<void()>> discarded;
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        tasks.swap(discarded); // Swap with the empty queue while workers are excluded
    }
    // `discarded` is destroyed here, outside the lock, so destructors of the dropped
    // tasks' bound arguments can safely call back into the pool.
}

inline ThreadPool::ThreadPool(const size_t threads) noexcept : stop(false), active_workers(threads) {
    workers.reserve(threads);
    for (size_t i = 0; i < threads; ++i) {
        workers.emplace_back(&ThreadPool::worker_loop, this, i);
    }
}

inline ThreadPool::~ThreadPool() noexcept {
    {
        // Written under the lock: a worker between its predicate check and wait()
        // would otherwise miss both the flag and the notification, and join() would hang.
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (auto& thread : workers) {
        if (thread.joinable()) { // Check shouldn't ever be false, but call just in case to guard against UB
            thread.join();
        }
    }
}

inline void ThreadPool::resize(const size_t threads) noexcept {
    const size_t current = workers.size();
    if (threads == current) return;

    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        active_workers = threads; // Workers with index >= threads will retire
    }

    if (threads < current) {
        // Wake everyone; only workers at indices [threads, current) see their retire condition.
        condition.notify_all();
        // Join without holding queue_mutex: the retiring workers need it to observe the change.
        for (size_t i = threads; i < current; ++i) {
            if (workers[i].joinable()) {
                workers[i].join();
            }
        }
        workers.resize(threads); // Drop the joined (now non-joinable) thread objects
        return;
    }

    // Need to grow: the new workers pick up any tasks already waiting in the queue.
    workers.reserve(threads);
    for (size_t i = current; i < threads; ++i) {
        workers.emplace_back(&ThreadPool::worker_loop, this, i);
    }
}
#endif //__THREADPOOL_H_