Skip to content

Commit a2d4277

Browse files
suo authored and facebook-github-bot committed
pin PythonActors to a single thread (#267)
Summary: Pull Request resolved: #267 As we did with StreamActor before it, it is good to run PythonActor handlers on one thread consistently, since a lot of the Python code we have assumes that it is called from a single thread, makes use of thread-local state, etc. So, we spawn one thread per PythonActor for it to run on. TODO: this does NOT preserve thread-local state between async endpoints and sync endpoints, which is a huge issue. I think the correct solution there is to detect whether an actor has any async endpoints, and if so, switch to a mode where absolutely all Python code is run on the asyncio event loop. That's a deeper refactor though, so will do this to unblock for now. Reviewed By: mariusae Differential Revision: D76603819 fbshipit-source-id: d8101084fe41a10d163a26dd450e1e0e4a2614ac
1 parent 7186689 commit a2d4277

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@ use pyo3::types::PyType;
3636
use serde::Deserialize;
3737
use serde::Serialize;
3838
use serde_bytes::ByteBuf;
39+
use tokio::runtime::Handle;
3940
use tokio::sync::Mutex;
4041
use tokio::sync::oneshot;
42+
use tokio::task::JoinHandle;
43+
use tracing::span::Id;
4144

4245
use crate::mailbox::PyMailbox;
4346
use crate::proc::InstanceWrapper;
@@ -284,6 +287,56 @@ impl Actor for PythonActor {
284287
Ok(Self { actor })
285288
})?)
286289
}
290+
291+
/// Specialize spawn_server_task for PythonActor, because we want to run the stream on a
292+
/// dedicated OS thread. We do this to guarantee tha all Python code is
293+
/// executed on the same thread, since often Python code uses thread-local
294+
/// state or otherwise assumes that it is called only from a single thread.
295+
fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
296+
where
297+
F: Future + Send + 'static,
298+
F::Output: Send + 'static,
299+
{
300+
let (join_tx, join_rx) = tokio::sync::oneshot::channel();
301+
// It is important that we spawn a standalone thread for the work here,
302+
// as opposed to using `spawn_blocking` to spawn a tokio-managed thread.
303+
// This is because the worker stream may call uninterruptible FFI code
304+
// that can deadlock (CUDA, NCCL).
305+
// If we use a tokio-managed blocking thread, then runtime teardown will
306+
// try to wait for tasks on that thread to reach an await point, and
307+
// hang forever.
308+
let builder = std::thread::Builder::new().name("python-actor".to_string());
309+
let _thread_handle = builder.spawn(move || {
310+
// Spawn a new thread with a single-threaded tokio runtime to run the
311+
// actor loop. We avoid the current-threaded runtime, so that we can
312+
// use `block_in_place` for nested async-to-sync-to-async flows.
313+
let rt = tokio::runtime::Builder::new_multi_thread()
314+
.worker_threads(1)
315+
.enable_io()
316+
.build()
317+
.unwrap();
318+
rt.block_on(async {
319+
tokio::task::block_in_place(|| {
320+
// Allow e.g. destructing py objects on this thread, which
321+
// can happen at shutdown when the a stream actors env map
322+
// for rvalues is dropped (e.g. P1673311499).
323+
// https://github.com/PyO3/pyo3/discussions/3499
324+
Python::with_gil(|py| {
325+
py.allow_threads(|| {
326+
let result = Handle::current().block_on(future);
327+
if join_tx.send(result).is_err() {
328+
panic!("could not send join result")
329+
}
330+
})
331+
})
332+
})
333+
})
334+
});
335+
336+
// In order to bridge the synchronous join handle with the async world,
337+
// smuggle the result through a channel.
338+
tokio::spawn(async move { join_rx.await.unwrap() })
339+
}
287340
}
288341

289342
/// Get the event loop state to run PythonActor handlers in. We construct a

python/tests/test_python_actors.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import operator
99
import os
1010
import re
11+
import threading
1112
from types import ModuleType
1213
from unittest.mock import AsyncMock, patch
1314

@@ -549,3 +550,44 @@ def _patch_output(msg):
549550

550551
with pytest.raises(monarch.actor_mesh.ActorError, match="ValueError: bad rank"):
551552
await fut
553+
554+
555+
class TLSActor(Actor):
    """An actor that manages thread-local state.

    Every endpoint reads or mutates a counter stored in a
    ``threading.local``, so observed values depend on which OS thread each
    endpoint handler runs on.
    """

    def __init__(self):
        self.local = threading.local()
        # NOTE(review): `value` is initialized only for the thread running
        # __init__; an endpoint invoked on a different thread would raise
        # AttributeError instead of seeing 0 — presumably intentional, since
        # this actor exists to pin down handler thread affinity.
        self.local.value = 0

    @endpoint
    def increment(self):
        # Sync endpoint: bumps the counter of the thread it runs on.
        self.local.value += 1

    @endpoint
    async def increment_async(self):
        # Async endpoint: bumps the counter of the event-loop thread.
        self.local.value += 1

    @endpoint
    def get(self):
        # Sync endpoint: returns the counter as seen from its thread.
        return self.local.value

    @endpoint
    async def get_async(self):
        # Async endpoint: returns the counter as seen from the event-loop
        # thread.
        return self.local.value
577+
578+
579+
async def test_actor_tls() -> None:
    """Test that thread-local state is respected.

    Two sync increments must be visible to a subsequent sync read, which
    holds only if all sync endpoints run on one consistent OS thread.
    """
    pm = await proc_mesh(gpus=1)
    am = await pm.spawn("tls", TLSActor)
    await am.increment.call_one()
    # TODO(suo): TLS is NOT preserved across async/sync endpoints, because currently
    # we run async endpoints on a different thread than sync ones.
    # Will fix this in a followup diff.

    # await am.increment_async.call_one()
    await am.increment.call_one()
    # await am.increment_async.call_one()

    # Both sync increments landed on the same thread-local, so the sync
    # view of the counter is 2.
    assert 2 == await am.get.call_one()
    # assert 4 == await am.get_async.call_one()

0 commit comments

Comments
 (0)