Skip to content

Commit b781b1e

Browse files
Fix waker behavior when invoked before a poll finishes (#54)
It's possible for a task's waker to be invoked in the middle of a call to that task's `poll` by the executor. We had accounted for that possibility if the task called *its own* waker, but that's not good enough: the waker can escape to other threads that can invoke it before `poll` finishes (e.g., if the task blocks to acquire a lock). This change fixes the waker behavior by clarifying the semantics of a call to `wake`: a task whose waker is invoked should not be blocked *when it next returns Pending to the executor*, and should be woken if that has already happened. To do this, we introduce a new `Sleeping` state for tasks, which has the same semantics as `Blocked` except that it is the only state recognized by waker invocations: a waker will only unblock a task that is in the `Sleeping` state. This also removes the special-case "woken by self" behavior -- being woken by *any* thread, not just by the task itself, is now enough to trigger this sleep logic.
1 parent b37e984 commit b781b1e

File tree

7 files changed

+158
-50
lines changed

7 files changed

+158
-50
lines changed

src/asynch.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,7 @@ pub fn block_on<F: Future>(future: F) -> F::Output {
148148
match future.as_mut().poll(cx) {
149149
Poll::Ready(result) => break result,
150150
Poll::Pending => {
151-
ExecutionState::with(|state| {
152-
state.current_mut().block_unless_self_woken();
153-
});
151+
ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
154152
}
155153
}
156154

src/runtime/execution.rs

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::runtime::failure::{init_panic_hook, persist_failure, persist_task_failure};
22
use crate::runtime::storage::{StorageKey, StorageMap};
33
use crate::runtime::task::clock::VectorClock;
4-
use crate::runtime::task::{Task, TaskId, TaskState, DEFAULT_INLINE_TASKS};
4+
use crate::runtime::task::{Task, TaskId, DEFAULT_INLINE_TASKS};
55
use crate::runtime::thread::continuation::PooledContinuation;
66
use crate::scheduler::{Schedule, Scheduler};
77
use crate::{Config, MaxSteps};
@@ -62,7 +62,12 @@ impl Execution {
6262

6363
EXECUTION_STATE.set(&state, move || {
6464
// Spawn `f` as the first task
65-
ExecutionState::spawn_thread(f, config.stack_size, None, Some(VectorClock::new()));
65+
ExecutionState::spawn_thread(
66+
f,
67+
config.stack_size,
68+
Some("main-thread".to_string()),
69+
Some(VectorClock::new()),
70+
);
6671

6772
// Run the test to completion
6873
while self.step(config) {}
@@ -93,21 +98,29 @@ impl Execution {
9398
NextStep::Task(Rc::clone(&task.continuation))
9499
}
95100
ScheduledTask::Finished => {
96-
let task_states = state
97-
.tasks
98-
.iter()
99-
.map(|t| (t.id, t.state, t.detached))
100-
.collect::<SmallVec<[_; DEFAULT_INLINE_TASKS]>>();
101-
if task_states
102-
.iter()
103-
.any(|(_, s, detached)| !detached && *s == TaskState::Blocked)
104-
{
101+
// The scheduler decided we're finished, so there are either no runnable tasks,
102+
// or all runnable tasks are detached and there are no unfinished attached
103+
// tasks. Therefore, it's a deadlock if there are unfinished attached tasks.
104+
if state.tasks.iter().any(|t| !t.finished() && !t.detached) {
105+
let blocked_tasks = state
106+
.tasks
107+
.iter()
108+
.filter(|t| !t.finished())
109+
.map(|t| {
110+
format!(
111+
"{} (task {}{}{})",
112+
t.name().unwrap_or_else(|| "<unknown>".to_string()),
113+
t.id().0,
114+
if t.detached { ", detached" } else { "" },
115+
if t.sleeping() { ", pending future" } else { "" },
116+
)
117+
})
118+
.collect::<Vec<_>>();
105119
NextStep::Failure(
106-
format!("deadlock! runnable tasks: {:?}", task_states),
120+
format!("deadlock! blocked tasks: [{}]", blocked_tasks.join(", ")),
107121
state.current_schedule.clone(),
108122
)
109123
} else {
110-
debug_assert!(state.tasks.iter().all(|t| t.detached || t.finished()));
111124
NextStep::Finished
112125
}
113126
}
@@ -502,21 +515,21 @@ impl ExecutionState {
502515
_ => {}
503516
}
504517

505-
let mut blocked_attached = false;
518+
let mut unfinished_attached = false;
506519
let runnable = self
507520
.tasks
508521
.iter()
509-
.inspect(|t| blocked_attached = blocked_attached || (t.blocked() && !t.detached))
522+
.inspect(|t| unfinished_attached = unfinished_attached || (!t.finished() && !t.detached))
510523
.filter(|t| t.runnable())
511524
.map(|t| t.id)
512525
.collect::<SmallVec<[_; DEFAULT_INLINE_TASKS]>>();
513526

514527
// We should finish execution when either
515528
// (1) There are no runnable tasks, or
516-
// (2) All runnable tasks have been detached AND there are no blocked attached tasks
517-
// If there are some blocked attached tasks and all runnable tasks are detached,
518-
// we must run some detached task so that blocked attached tasks may become unblocked.
519-
if runnable.is_empty() || (!blocked_attached && runnable.iter().all(|id| self.get(*id).detached)) {
529+
// (2) All runnable tasks have been detached AND there are no unfinished attached tasks
530+
// If there are some unfinished attached tasks and all runnable tasks are detached, we must
531+
// run some detached task to give them a chance to unblock some unfinished attached task.
532+
if runnable.is_empty() || (!unfinished_attached && runnable.iter().all(|id| self.get(*id).detached)) {
520533
self.next_task = ScheduledTask::Finished;
521534
return Ok(());
522535
}

src/runtime/storage.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub(crate) struct StorageKey(pub usize, pub usize); // (identifier, type)
1111
/// Values are Option<_> because we need to be able to incrementally destruct them, as it's valid
1212
/// for TLS destructors to initialize new TLS slots. When a slot is destructed, its key is removed
1313
/// from `order` and its value is replaced with None.
14+
#[derive(Debug)]
1415
pub(crate) struct StorageMap {
1516
locals: HashMap<StorageKey, Option<Box<dyn Any>>>,
1617
order: VecDeque<StorageKey>,

src/runtime/task/mod.rs

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub(crate) const DEFAULT_INLINE_TASKS: usize = 16;
3838

3939
/// A `Task` represents a user-level unit of concurrency. Each task has an `id` that is unique within
4040
/// the execution, and a `state` reflecting whether the task is runnable (enabled) or not.
41+
#[derive(Debug)]
4142
pub(crate) struct Task {
4243
pub(super) id: TaskId,
4344
pub(super) state: TaskState,
@@ -50,8 +51,8 @@ pub(crate) struct Task {
5051
waiter: Option<TaskId>,
5152

5253
waker: Waker,
53-
// Remember whether the waker was invoked while we were running so we don't re-block
54-
woken_by_self: bool,
54+
// Remember whether the waker was invoked while we were running
55+
woken: bool,
5556

5657
name: Option<String>,
5758

@@ -76,7 +77,7 @@ impl Task {
7677
clock,
7778
waiter: None,
7879
waker,
79-
woken_by_self: false,
80+
woken: false,
8081
detached: false,
8182
name,
8283
local_storage: StorageMap::new(),
@@ -106,10 +107,7 @@ impl Task {
106107
let waker = ExecutionState::with(|state| state.current_mut().waker());
107108
let cx = &mut Context::from_waker(&waker);
108109
while future.as_mut().poll(cx).is_pending() {
109-
ExecutionState::with(|state| {
110-
// We need to block before thread::switch() unless we woke ourselves up
111-
state.current_mut().block_unless_self_woken();
112-
});
110+
ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
113111
thread::switch();
114112
}
115113
},
@@ -132,6 +130,10 @@ impl Task {
132130
self.state == TaskState::Blocked
133131
}
134132

133+
pub(crate) fn sleeping(&self) -> bool {
134+
self.state == TaskState::Sleeping
135+
}
136+
135137
pub(crate) fn finished(&self) -> bool {
136138
self.state == TaskState::Finished
137139
}
@@ -149,6 +151,11 @@ impl Task {
149151
self.state = TaskState::Blocked;
150152
}
151153

154+
pub(crate) fn sleep(&mut self) {
155+
assert!(self.state != TaskState::Finished);
156+
self.state = TaskState::Sleeping;
157+
}
158+
152159
pub(crate) fn unblock(&mut self) {
153160
// Note we don't assert the task is blocked here. For example, a task invoking its own waker
154161
// will not be blocked when this is called.
@@ -161,23 +168,25 @@ impl Task {
161168
self.state = TaskState::Finished;
162169
}
163170

164-
/// Potentially block this task after it was polled by the executor.
171+
/// Potentially put this task to sleep after it was polled by the executor, unless someone has
172+
/// called its waker first.
165173
///
166-
/// A synchronous Task should never call this, because we want threads to be
167-
/// enabled-by-default to avoid bugs where Shuttle incorrectly omits a potential execution.
168-
/// We also need to handle a special case where a task invoked its own waker, in which case
169-
/// we should not block the task.
170-
pub(crate) fn block_unless_self_woken(&mut self) {
171-
let was_woken_by_self = std::mem::replace(&mut self.woken_by_self, false);
172-
if !was_woken_by_self {
173-
self.block();
174+
/// A synchronous Task should never call this, because we want threads to be enabled-by-default
175+
/// to avoid bugs where Shuttle incorrectly omits a potential execution.
176+
pub(crate) fn sleep_unless_woken(&mut self) {
177+
let was_woken = std::mem::replace(&mut self.woken, false);
178+
if !was_woken {
179+
self.sleep();
174180
}
175181
}
176182

177-
/// Remember that we have been unblocked while we were currently running, and therefore should
178-
/// not be blocked again by `block_unless_self_woken`.
179-
pub(super) fn set_woken_by_self(&mut self) {
180-
self.woken_by_self = true;
183+
/// Remember that our waker has been called, and so we should not block the next time the
184+
/// executor tries to put us to sleep. If we are already sleeping, unblock immediately.
185+
pub(super) fn wake(&mut self) {
186+
self.woken = true;
187+
if self.state == TaskState::Sleeping {
188+
self.unblock();
189+
}
181190
}
182191

183192
/// Register a waiter for this thread to terminate. Returns a boolean indicating whether the
@@ -240,8 +249,13 @@ impl Task {
240249

241250
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
242251
pub(crate) enum TaskState {
252+
/// Available to be scheduled
243253
Runnable,
254+
/// Blocked in a synchronization operation
244255
Blocked,
256+
/// A `Future` that returned `Pending` is waiting to be woken up
257+
Sleeping,
258+
/// Task has finished
245259
Finished,
246260
}
247261

src/runtime/task/waker.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,7 @@ unsafe fn raw_waker_wake(data: *const ()) {
3838
return;
3939
}
4040

41-
waiter.unblock();
42-
43-
let current = state.current_mut();
44-
if current.id() == task_id {
45-
current.set_woken_by_self();
46-
}
41+
waiter.wake();
4742
});
4843
}
4944

src/runtime/thread/continuation.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ impl DerefMut for PooledContinuation {
247247
}
248248
}
249249

250+
impl std::fmt::Debug for PooledContinuation {
251+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
252+
f.debug_struct("PooledContinuation").finish()
253+
}
254+
}
255+
250256
// Safety: these aren't sent across real threads
251257
unsafe impl Send for PooledContinuation {}
252258

tests/asynch/waker.rs

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
use futures::future::poll_fn;
2+
use shuttle::sync::atomic::{AtomicBool, Ordering};
3+
use shuttle::sync::Mutex;
4+
use shuttle::{asynch, check_dfs, thread};
15
use std::future::Future;
26
use std::pin::Pin;
7+
use std::sync::Arc;
38
use std::task::{Context, Poll, Waker};
4-
5-
use shuttle::{asynch, check_dfs};
69
use test_env_log::test;
710

811
#[test]
@@ -49,3 +52,81 @@ fn wake_after_finish() {
4952
None,
5053
)
5154
}
55+
56+
// Test that we can pass wakers across threads and have them work correctly
57+
#[test]
58+
fn wake_during_poll() {
59+
check_dfs(
60+
|| {
61+
let waker: Arc<Mutex<Option<Waker>>> = Arc::new(Mutex::new(None));
62+
let waker_clone = Arc::clone(&waker);
63+
let signal = Arc::new(AtomicBool::new(false));
64+
let signal_clone = Arc::clone(&signal);
65+
66+
// This thread might invoke `wake` before the other task finishes running a single
67+
// invocation of `poll`. If that happens, that task must not be blocked.
68+
thread::spawn(move || {
69+
signal_clone.store(true, Ordering::SeqCst);
70+
71+
if let Some(waker) = waker_clone.lock().unwrap().take() {
72+
waker.wake();
73+
}
74+
});
75+
76+
asynch::block_on(poll_fn(move |cx| {
77+
*waker.lock().unwrap() = Some(cx.waker().clone());
78+
79+
if signal.load(Ordering::SeqCst) {
80+
Poll::Ready(())
81+
} else {
82+
Poll::Pending
83+
}
84+
}));
85+
},
86+
None,
87+
);
88+
}
89+
90+
// Test that a waker invocation doesn't unblock a task that is blocked due to synchronization
91+
// operations
92+
#[test]
93+
fn wake_during_blocked_poll() {
94+
static RAN_WAKER: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
95+
check_dfs(
96+
|| {
97+
let waker: Arc<Mutex<Option<Waker>>> = Arc::new(Mutex::new(None));
98+
let waker_clone = Arc::clone(&waker);
99+
let counter = Arc::new(Mutex::new(0));
100+
let counter_clone = Arc::clone(&counter);
101+
102+
thread::spawn(move || {
103+
let mut counter = counter_clone.lock().unwrap();
104+
thread::yield_now();
105+
*counter += 1;
106+
});
107+
108+
// If this `wake()` invocation happens while the thread above holds the `counter` lock
109+
// and the `block_on` task below is blocked waiting to acquire that same lock, then
110+
// `wake` must not unblock the `block_on` task. That is, `wake` should prevent the task
111+
// from being blocked *the next time it returns Pending*, not just any time it is
112+
// blocked.
113+
thread::spawn(move || {
114+
if let Some(waker) = waker_clone.lock().unwrap().take() {
115+
RAN_WAKER.store(true, Ordering::SeqCst);
116+
waker.wake();
117+
}
118+
});
119+
120+
asynch::block_on(poll_fn(move |cx| {
121+
*waker.lock().unwrap() = Some(cx.waker().clone());
122+
123+
let mut counter = counter.lock().unwrap();
124+
*counter += 1;
125+
126+
Poll::Ready(())
127+
}));
128+
},
129+
None,
130+
);
131+
assert!(RAN_WAKER.load(Ordering::SeqCst), "waker was not invoked by any test");
132+
}

0 commit comments

Comments
 (0)