Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(141)

Side by Side Diff: webrtc/base/task_queue_win.cc

Issue 2750853002: TaskQueue[Win] DOS handling (Closed)
Patch Set: Cleanup Created 3 years, 9 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
« no previous file with comments | « webrtc/base/task_queue.h ('k') | no next file » | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
1 /* 1 /*
2 * Copyright 2016 The WebRTC Project Authors. All rights reserved. 2 * Copyright 2016 The WebRTC Project Authors. All rights reserved.
3 * 3 *
4 * Use of this source code is governed by a BSD-style license 4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source 5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found 6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may 7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree. 8 * be found in the AUTHORS file in the root of the source tree.
9 */ 9 */
10 10
11 #include "webrtc/base/task_queue.h" 11 #include "webrtc/base/task_queue.h"
12 12
13 #include <mmsystem.h> 13 #include <mmsystem.h>
14 #include <string.h> 14 #include <string.h>
15 15
16 #include <algorithm> 16 #include <algorithm>
17 #include <queue> 17 #include <queue>
18 18
19 #include "webrtc/base/arraysize.h"
19 #include "webrtc/base/checks.h" 20 #include "webrtc/base/checks.h"
20 #include "webrtc/base/logging.h" 21 #include "webrtc/base/logging.h"
21 #include "webrtc/base/safe_conversions.h" 22 #include "webrtc/base/safe_conversions.h"
22 #include "webrtc/base/timeutils.h" 23 #include "webrtc/base/timeutils.h"
23 24
24 namespace rtc { 25 namespace rtc {
25 namespace { 26 namespace {
26 #define WM_RUN_TASK WM_USER + 1 27 #define WM_RUN_TASK WM_USER + 1
27 #define WM_QUEUE_DELAYED_TASK WM_USER + 2 28 #define WM_QUEUE_DELAYED_TASK WM_USER + 2
28 29
(...skipping 81 matching lines...) Expand 10 before | Expand all | Expand 10 after
110 // There are two basic workarounds, one using const_cast, which would also 111 // There are two basic workarounds, one using const_cast, which would also
111 // make the key (|due_time|), non-const and the other is to make the non-key 112 // make the key (|due_time|), non-const and the other is to make the non-key
112 // (|task|), mutable. 113 // (|task|), mutable.
113 // Because of this, the |task| variable is made private and can only be 114 // Because of this, the |task| variable is made private and can only be
114 // mutated by calling the |Run()| method. 115 // mutated by calling the |Run()| method.
115 mutable std::unique_ptr<QueuedTask> task_; 116 mutable std::unique_ptr<QueuedTask> task_;
116 }; 117 };
117 118
118 class MultimediaTimer { 119 class MultimediaTimer {
119 public: 120 public:
120 MultimediaTimer() : event_(::CreateEvent(nullptr, false, false, nullptr)) {} 121 // Note: We create an event that requires manual reset.
122 MultimediaTimer() : event_(::CreateEvent(nullptr, true, false, nullptr)) {}
121 123
122 ~MultimediaTimer() { 124 ~MultimediaTimer() {
123 Cancel(); 125 Cancel();
124 ::CloseHandle(event_); 126 ::CloseHandle(event_);
125 } 127 }
126 128
127 bool StartOneShotTimer(UINT delay_ms) { 129 bool StartOneShotTimer(UINT delay_ms) {
128 RTC_DCHECK_EQ(0, timer_id_); 130 RTC_DCHECK_EQ(0, timer_id_);
129 RTC_DCHECK(event_ != nullptr); 131 RTC_DCHECK(event_ != nullptr);
130 timer_id_ = 132 timer_id_ =
131 ::timeSetEvent(delay_ms, 0, reinterpret_cast<LPTIMECALLBACK>(event_), 0, 133 ::timeSetEvent(delay_ms, 0, reinterpret_cast<LPTIMECALLBACK>(event_), 0,
132 TIME_ONESHOT | TIME_CALLBACK_EVENT_SET); 134 TIME_ONESHOT | TIME_CALLBACK_EVENT_SET);
133 return timer_id_ != 0; 135 return timer_id_ != 0;
134 } 136 }
135 137
136 void Cancel() { 138 void Cancel() {
139 ::ResetEvent(event_);
137 if (timer_id_) { 140 if (timer_id_) {
138 ::timeKillEvent(timer_id_); 141 ::timeKillEvent(timer_id_);
139 timer_id_ = 0; 142 timer_id_ = 0;
140 } 143 }
141 } 144 }
142 145
143 HANDLE* event_for_wait() { return &event_; } 146 HANDLE* event_for_wait() { return &event_; }
144 147
145 private: 148 private:
146 HANDLE event_ = nullptr; 149 HANDLE event_ = nullptr;
147 MMRESULT timer_id_ = 0; 150 MMRESULT timer_id_ = 0;
148 151
149 RTC_DISALLOW_COPY_AND_ASSIGN(MultimediaTimer); 152 RTC_DISALLOW_COPY_AND_ASSIGN(MultimediaTimer);
150 }; 153 };
151 154
152 } // namespace 155 } // namespace
153 156
154 class TaskQueue::ThreadState { 157 class TaskQueue::ThreadState {
155 public: 158 public:
156 ThreadState() {} 159 explicit ThreadState(HANDLE in_queue) : in_queue_(in_queue) {}
157 ~ThreadState() {} 160 ~ThreadState() {}
158 161
159 void RunThreadMain(); 162 void RunThreadMain();
160 163
161 private: 164 private:
162 bool ProcessQueuedMessages(); 165 bool ProcessQueuedMessages();
163 void RunDueTasks(); 166 void RunDueTasks();
164 void ScheduleNextTimer(); 167 void ScheduleNextTimer();
165 void CancelTimers(); 168 void CancelTimers();
166 169
167 // Since priority_queue<> by defult orders items in terms of 170 // Since priority_queue<> by defult orders items in terms of
168 // largest->smallest, using std::less<>, and we want smallest->largest, 171 // largest->smallest, using std::less<>, and we want smallest->largest,
169 // we would like to use std::greater<> here. Alas it's only available in 172 // we would like to use std::greater<> here. Alas it's only available in
170 // C++14 and later, so we roll our own compare template that that relies on 173 // C++14 and later, so we roll our own compare template that that relies on
171 // operator<(). 174 // operator<().
172 template <typename T> 175 template <typename T>
173 struct greater { 176 struct greater {
174 bool operator()(const T& l, const T& r) { return l > r; } 177 bool operator()(const T& l, const T& r) { return l > r; }
175 }; 178 };
176 179
177 MultimediaTimer timer_; 180 MultimediaTimer timer_;
178 std::priority_queue<DelayedTaskInfo, 181 std::priority_queue<DelayedTaskInfo,
179 std::vector<DelayedTaskInfo>, 182 std::vector<DelayedTaskInfo>,
180 greater<DelayedTaskInfo>> 183 greater<DelayedTaskInfo>>
181 timer_tasks_; 184 timer_tasks_;
182 UINT_PTR timer_id_ = 0; 185 UINT_PTR timer_id_ = 0;
186 HANDLE in_queue_;
183 }; 187 };
184 188
185 TaskQueue::TaskQueue(const char* queue_name, Priority priority /*= NORMAL*/) 189 TaskQueue::TaskQueue(const char* queue_name, Priority priority /*= NORMAL*/)
186 : thread_(&TaskQueue::ThreadMain, 190 : thread_(&TaskQueue::ThreadMain,
187 this, 191 this,
188 queue_name, 192 queue_name,
189 TaskQueuePriorityToThreadPriority(priority)) { 193 TaskQueuePriorityToThreadPriority(priority)),
194 in_queue_(::CreateEvent(nullptr, true, false, nullptr)) {
190 RTC_DCHECK(queue_name); 195 RTC_DCHECK(queue_name);
196 RTC_DCHECK(in_queue_);
191 thread_.Start(); 197 thread_.Start();
192 Event event(false, false); 198 Event event(false, false);
193 ThreadStartupData startup = {&event, this}; 199 ThreadStartupData startup = {&event, this};
194 RTC_CHECK(thread_.QueueAPC(&InitializeQueueThread, 200 RTC_CHECK(thread_.QueueAPC(&InitializeQueueThread,
195 reinterpret_cast<ULONG_PTR>(&startup))); 201 reinterpret_cast<ULONG_PTR>(&startup)));
196 event.Wait(Event::kForever); 202 event.Wait(Event::kForever);
197 } 203 }
198 204
199 TaskQueue::~TaskQueue() { 205 TaskQueue::~TaskQueue() {
200 RTC_DCHECK(!IsCurrent()); 206 RTC_DCHECK(!IsCurrent());
201 while (!::PostThreadMessage(thread_.GetThreadRef(), WM_QUIT, 0, 0)) { 207 while (!::PostThreadMessage(thread_.GetThreadRef(), WM_QUIT, 0, 0)) {
202 RTC_CHECK_EQ(ERROR_NOT_ENOUGH_QUOTA, ::GetLastError()); 208 RTC_CHECK_EQ(ERROR_NOT_ENOUGH_QUOTA, ::GetLastError());
203 Sleep(1); 209 Sleep(1);
204 } 210 }
205 thread_.Stop(); 211 thread_.Stop();
212 ::CloseHandle(in_queue_);
206 } 213 }
207 214
208 // static 215 // static
209 TaskQueue* TaskQueue::Current() { 216 TaskQueue* TaskQueue::Current() {
210 return static_cast<TaskQueue*>(::TlsGetValue(GetQueuePtrTls())); 217 return static_cast<TaskQueue*>(::TlsGetValue(GetQueuePtrTls()));
211 } 218 }
212 219
213 // static 220 // static
214 bool TaskQueue::IsCurrent(const char* queue_name) { 221 bool TaskQueue::IsCurrent(const char* queue_name) {
215 TaskQueue* current = Current(); 222 TaskQueue* current = Current();
216 return current && current->thread_.name().compare(queue_name) == 0; 223 return current && current->thread_.name().compare(queue_name) == 0;
217 } 224 }
218 225
219 bool TaskQueue::IsCurrent() const { 226 bool TaskQueue::IsCurrent() const {
220 return IsThreadRefEqual(thread_.GetThreadRef(), CurrentThreadRef()); 227 return IsThreadRefEqual(thread_.GetThreadRef(), CurrentThreadRef());
221 } 228 }
222 229
223 void TaskQueue::PostTask(std::unique_ptr<QueuedTask> task) { 230 void TaskQueue::PostTask(std::unique_ptr<QueuedTask> task) {
224 if (::PostThreadMessage(thread_.GetThreadRef(), WM_RUN_TASK, 0, 231 rtc::CritScope lock(&pending_lock_);
225 reinterpret_cast<LPARAM>(task.get()))) { 232 pending_.push(std::move(task));
226 task.release(); 233 ::SetEvent(in_queue_);
227 }
228 } 234 }
229 235
230 void TaskQueue::PostDelayedTask(std::unique_ptr<QueuedTask> task, 236 void TaskQueue::PostDelayedTask(std::unique_ptr<QueuedTask> task,
231 uint32_t milliseconds) { 237 uint32_t milliseconds) {
232 if (!milliseconds) { 238 if (!milliseconds) {
233 PostTask(std::move(task)); 239 PostTask(std::move(task));
234 return; 240 return;
235 } 241 }
236 242
237 // TODO(tommi): Avoid this allocation. It is currently here since 243 // TODO(tommi): Avoid this allocation. It is currently here since
(...skipping 23 matching lines...) Expand all
261 delete reply_task_ptr; 267 delete reply_task_ptr;
262 } 268 }
263 }); 269 });
264 } 270 }
265 271
266 void TaskQueue::PostTaskAndReply(std::unique_ptr<QueuedTask> task, 272 void TaskQueue::PostTaskAndReply(std::unique_ptr<QueuedTask> task,
267 std::unique_ptr<QueuedTask> reply) { 273 std::unique_ptr<QueuedTask> reply) {
268 return PostTaskAndReply(std::move(task), std::move(reply), Current()); 274 return PostTaskAndReply(std::move(task), std::move(reply), Current());
269 } 275 }
270 276
277 void TaskQueue::RunPendingTasks() {
278 while (true) {
279 std::unique_ptr<QueuedTask> task;
280 {
281 rtc::CritScope lock(&pending_lock_);
282 if (pending_.empty())
283 break;
284 task = std::move(pending_.front());
285 pending_.pop();
286 }
287
288 if (!task->Run())
289 task.release();
290 }
291 }
292
271 // static 293 // static
272 void TaskQueue::ThreadMain(void* context) { 294 void TaskQueue::ThreadMain(void* context) {
273 ThreadState state; 295 ThreadState state(static_cast<TaskQueue*>(context)->in_queue_);
274 state.RunThreadMain(); 296 state.RunThreadMain();
275 } 297 }
276 298
277 void TaskQueue::ThreadState::RunThreadMain() { 299 void TaskQueue::ThreadState::RunThreadMain() {
300 HANDLE handles[2] = { *timer_.event_for_wait(), in_queue_ };
278 while (true) { 301 while (true) {
279 // Make sure we do an alertable wait as that's required to allow APCs to run 302 // Make sure we do an alertable wait as that's required to allow APCs to run
280 // (e.g. required for InitializeQueueThread and stopping the thread in 303 // (e.g. required for InitializeQueueThread and stopping the thread in
281 // PlatformThread). 304 // PlatformThread).
282 DWORD result = ::MsgWaitForMultipleObjectsEx( 305 DWORD result = ::MsgWaitForMultipleObjectsEx(
283 1, timer_.event_for_wait(), INFINITE, QS_ALLEVENTS, MWMO_ALERTABLE); 306 arraysize(handles), handles, INFINITE, QS_ALLEVENTS, MWMO_ALERTABLE);
284 RTC_CHECK_NE(WAIT_FAILED, result); 307 RTC_CHECK_NE(WAIT_FAILED, result);
285 if (result == (WAIT_OBJECT_0 + 1)) { 308 if (result == (WAIT_OBJECT_0 + 2)) {
286 // There are messages in the message queue that need to be handled. 309 // There are messages in the message queue that need to be handled.
287 if (!ProcessQueuedMessages()) 310 if (!ProcessQueuedMessages())
288 break; 311 break;
289 } else if (result == WAIT_OBJECT_0) { 312 }
313
314 if (result == WAIT_OBJECT_0 || (!timer_tasks_.empty() &&
315 ::WaitForSingleObject(*timer_.event_for_wait(), 0) == WAIT_OBJECT_0)) {
290 // The multimedia timer was signaled. 316 // The multimedia timer was signaled.
291 timer_.Cancel(); 317 timer_.Cancel();
292 RTC_DCHECK(!timer_tasks_.empty());
293 RunDueTasks(); 318 RunDueTasks();
294 ScheduleNextTimer(); 319 ScheduleNextTimer();
295 } else { 320 }
296 RTC_DCHECK_EQ(WAIT_IO_COMPLETION, result); 321
322 if (result == (WAIT_OBJECT_0 + 1)) {
323 ::ResetEvent(in_queue_);
324 TaskQueue::Current()->RunPendingTasks();
297 } 325 }
298 } 326 }
299 } 327 }
300 328
301 bool TaskQueue::ThreadState::ProcessQueuedMessages() { 329 bool TaskQueue::ThreadState::ProcessQueuedMessages() {
302 MSG msg = {}; 330 MSG msg = {};
331 // To protect against overly busy message queues, we limit the time
332 // we process tasks to a few milliseconds. If we don't do that, there's
333 // a chance that timer tasks won't ever run.
334 static const int kMaxTaskProcessingTimeMs = 500;
335 auto start = GetTick();
303 while (::PeekMessage(&msg, nullptr, 0, 0, PM_REMOVE) && 336 while (::PeekMessage(&msg, nullptr, 0, 0, PM_REMOVE) &&
304 msg.message != WM_QUIT) { 337 msg.message != WM_QUIT) {
305 if (!msg.hwnd) { 338 if (!msg.hwnd) {
306 switch (msg.message) { 339 switch (msg.message) {
340 // TODO(tommi): Stop using this way of queueing tasks.
307 case WM_RUN_TASK: { 341 case WM_RUN_TASK: {
308 QueuedTask* task = reinterpret_cast<QueuedTask*>(msg.lParam); 342 QueuedTask* task = reinterpret_cast<QueuedTask*>(msg.lParam);
309 if (task->Run()) 343 if (task->Run())
310 delete task; 344 delete task;
311 break; 345 break;
312 } 346 }
313 case WM_QUEUE_DELAYED_TASK: { 347 case WM_QUEUE_DELAYED_TASK: {
314 std::unique_ptr<DelayedTaskInfo> info( 348 std::unique_ptr<DelayedTaskInfo> info(
315 reinterpret_cast<DelayedTaskInfo*>(msg.lParam)); 349 reinterpret_cast<DelayedTaskInfo*>(msg.lParam));
316 bool need_to_schedule_timers = 350 bool need_to_schedule_timers =
(...skipping 15 matching lines...) Expand all
332 break; 366 break;
333 } 367 }
334 default: 368 default:
335 RTC_NOTREACHED(); 369 RTC_NOTREACHED();
336 break; 370 break;
337 } 371 }
338 } else { 372 } else {
339 ::TranslateMessage(&msg); 373 ::TranslateMessage(&msg);
340 ::DispatchMessage(&msg); 374 ::DispatchMessage(&msg);
341 } 375 }
376
377 if (GetTick() > start + kMaxTaskProcessingTimeMs)
378 break;
342 } 379 }
343 return msg.message != WM_QUIT; 380 return msg.message != WM_QUIT;
344 } 381 }
345 382
346 void TaskQueue::ThreadState::RunDueTasks() { 383 void TaskQueue::ThreadState::RunDueTasks() {
347 RTC_DCHECK(!timer_tasks_.empty()); 384 RTC_DCHECK(!timer_tasks_.empty());
348 auto now = GetTick(); 385 auto now = GetTick();
349 do { 386 do {
350 const auto& top = timer_tasks_.top(); 387 const auto& top = timer_tasks_.top();
351 if (top.due_time() > now) 388 if (top.due_time() > now)
(...skipping 17 matching lines...) Expand all
369 406
370 void TaskQueue::ThreadState::CancelTimers() { 407 void TaskQueue::ThreadState::CancelTimers() {
371 timer_.Cancel(); 408 timer_.Cancel();
372 if (timer_id_) { 409 if (timer_id_) {
373 ::KillTimer(nullptr, timer_id_); 410 ::KillTimer(nullptr, timer_id_);
374 timer_id_ = 0; 411 timer_id_ = 0;
375 } 412 }
376 } 413 }
377 414
378 } // namespace rtc 415 } // namespace rtc
OLDNEW
« no previous file with comments | « webrtc/base/task_queue.h ('k') | no next file » | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698