diff --git a/hyperactor/src/config.rs b/hyperactor/src/config.rs index c10a66a3..815c422d 100644 --- a/hyperactor/src/config.rs +++ b/hyperactor/src/config.rs @@ -43,6 +43,9 @@ declare_attrs! { /// Flag indicating if this is a managed subprocess pub attr IS_MANAGED_SUBPROCESS: bool = false; + + /// Maximum number of supervision events that can be buffered by client handlers + pub attr MAX_SUPERVISION_EVENTS: usize = 10; } /// Load configuration from environment variables diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index 20b5aed9..a8af0619 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -1249,7 +1249,7 @@ struct InstanceState { signal: PortHandle, /// The actor's supervision port. This is used to send - /// supervision event to the actor. + /// supervision event to the actor (usually by its children). supervision_port: PortHandle, /// An observer that stores the current status of the actor. diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index f517d4e6..82ef482a 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -27,6 +27,7 @@ use hyperactor::message::Bindings; use hyperactor::message::Castable; use hyperactor::message::IndexedErasedUnbound; use hyperactor::message::Unbind; +use hyperactor::supervision::ActorSupervisionEvent; use ndslice::Range; use ndslice::Selection; use ndslice::Shape; @@ -36,6 +37,7 @@ use ndslice::dsl; use ndslice::selection::ReifyView; use serde::Deserialize; use serde::Serialize; +use tokio::sync::mpsc; use crate::Mesh; use crate::comm::multicast::CastMessage; @@ -141,26 +143,35 @@ pub struct RootActorMesh<'a, A: RemoteActor> { proc_mesh: ProcMeshRef<'a>, name: String, pub(crate) ranks: Vec>, // temporary until we remove `ArcActorMesh`. 
+ actor_supervision_rx: mpsc::Receiver, } impl<'a, A: RemoteActor> RootActorMesh<'a, A> { - pub(crate) fn new(proc_mesh: &'a ProcMesh, name: String, ranks: Vec>) -> Self { + pub(crate) fn new( + proc_mesh: &'a ProcMesh, + name: String, + actor_supervision_rx: mpsc::Receiver, + ranks: Vec>, + ) -> Self { Self { proc_mesh: ProcMeshRef::Borrowed(proc_mesh), name, ranks, + actor_supervision_rx, } } pub(crate) fn new_shared( proc_mesh: Arc, name: String, + actor_supervision_rx: mpsc::Receiver, ranks: Vec>, ) -> Self { Self { proc_mesh: ProcMeshRef::Shared(proc_mesh), name, ranks, + actor_supervision_rx, } } @@ -198,6 +209,21 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> { } Ok(()) } + + /// Waits for the next supervision event routed to this actor mesh. Returns + /// None once the underlying supervision event channel has been closed. + pub async fn next(&mut self) -> Option { + let result = self.actor_supervision_rx.recv().await; + match result.as_ref() { + Some(event) => { + tracing::info!("Received supervision event: {event:?}"); + } + None => { + tracing::info!("Closed!"); + } + }; + result + } } #[async_trait] diff --git a/hyperactor_mesh/src/proc_mesh/mod.rs b/hyperactor_mesh/src/proc_mesh/mod.rs index e7130922..123390bd 100644 --- a/hyperactor_mesh/src/proc_mesh/mod.rs +++ b/hyperactor_mesh/src/proc_mesh/mod.rs @@ -36,6 +36,8 @@ use hyperactor::supervision::ActorSupervisionEvent; use ndslice::Range; use ndslice::Shape; use ndslice::ShapeError; +use tokio::sync::Mutex; +use tokio::sync::mpsc; use crate::CommActor; use crate::Mesh; @@ -64,12 +66,14 @@ fn global_router() -> &'static MailboxRouter { GLOBAL_ROUTER.get_or_init(MailboxRouter::new) } +type ActorEventRouter = Arc>>>; /// A ProcMesh maintains a mesh of procs whose lifecycles are managed by /// an allocator. pub struct ProcMesh { // The underlying set of events. It is None if it has been transferred to // a proc event observer. 
event_state: Option, + actor_event_router: ActorEventRouter, shape: Shape, ranks: Vec<(ProcId, (ChannelAddr, ActorRef))>, #[allow(dead_code)] // will be used in subsequent diff @@ -198,11 +202,13 @@ impl ProcMesh { global_router().bind(alloc.world_id().clone().into(), router.clone()); global_router().bind(client_proc_id.into(), router.clone()); + // TODO: No actor bound to "supervisor" yet. let supervisor = client_proc.attach("supervisor")?; let (supervison_port, supervision_events) = supervisor.open_port(); // Now, configure the full mesh, so that the local agents are wired up to // our router. + // No actor bound to this "client" yet let client = client_proc.attach("client")?; // Map of procs -> channel addresses @@ -277,6 +283,7 @@ impl ProcMesh { alloc: Box::new(alloc), supervision_events, }), + actor_event_router: Arc::new(Mutex::new(HashMap::new())), shape, ranks: proc_ids .into_iter() @@ -359,11 +366,23 @@ impl ProcMesh { where A::Params: RemoteMessage, { - Ok(RootActorMesh::new( + let actor_supervision_buffer_len = + hyperactor::config::global::get(hyperactor::config::MAX_SUPERVISION_EVENTS); + let (tx, rx) = mpsc::channel::(actor_supervision_buffer_len); + { + // Instantiate supervision routing BEFORE spawning the actor mesh. + self.actor_event_router + .lock() + .await + .insert(actor_name.to_string(), tx); + } + let root_mesh = RootActorMesh::new( self, actor_name.to_string(), + rx, Self::spawn_on_procs::(&self.client, self.agents(), actor_name, params).await?, - )) + ); + Ok(root_mesh) } /// A client used to communicate with any member of this mesh. @@ -390,8 +409,10 @@ impl ProcMesh { .enumerate() .map(|(rank, (proc_id, _))| (proc_id.clone(), rank)) .collect(), + actor_event_router: self.actor_event_router.clone(), }) } + pub fn shape(&self) -> &Shape { &self.shape } @@ -420,11 +441,14 @@ impl fmt::Display for ProcEvent { } } +type ActorMeshName = String; + /// An event stream of [`ProcEvent`] // TODO: consider using streams for this. 
pub struct ProcEvents { event_state: EventState, ranks: HashMap, + actor_event_router: ActorEventRouter, } impl ProcEvents { @@ -436,6 +460,11 @@ impl ProcEvents { result = self.event_state.alloc.next() => { // Don't disable the outer branch on None: this is always terminal. let Some(alloc_event) = result else { + { + let mut map = self.actor_event_router.lock().await; + // Drop every sender held by the router so the actor meshes' receivers can observe closure. + map.clear(); + } break None; }; @@ -452,11 +481,22 @@ impl ProcEvents { break Some(ProcEvent::Stopped(*rank, reason)); } Ok(event) = self.event_state.supervision_events.recv() => { - let (actor_id, actor_status) = event.into_inner(); + let (actor_id, actor_status) = event.clone().into_inner(); let Some(rank) = self.ranks.get(actor_id.proc_id()) else { tracing::warn!("received supervision event for unmapped actor {}", actor_id); continue; }; + // transmit to the correct root actor mesh. + { + let map = self.actor_event_router.lock().await; + let Some(tx) = map.get(actor_id.name()) else { + tracing::warn!("received supervision event for unregistered actor {}", actor_id); + continue; + }; + tx.send(event).await.unwrap(); + } + // TODO: Actor supervision events need to be wired to the frontend. + // TODO: This event should be handled by the proc mesh if unhandled by actor mesh. break Some(ProcEvent::Crashed(*rank, actor_status.to_string())) } } @@ -487,9 +527,20 @@ impl SharedSpawnable for Arc { where A::Params: RemoteMessage, { + let actor_supervision_buffer = + hyperactor::config::global::get(hyperactor::config::MAX_SUPERVISION_EVENTS); + let (tx, rx) = mpsc::channel::(actor_supervision_buffer); + { + // Instantiate supervision routing BEFORE spawning the actor mesh. 
+ self.actor_event_router + .lock() + .await + .insert(actor_name.to_string(), tx); + } Ok(RootActorMesh::new_shared( Arc::clone(self), actor_name.to_string(), + rx, ProcMesh::spawn_on_procs::(&self.client, self.agents(), actor_name, params).await?, )) } @@ -639,7 +690,7 @@ mod tests { let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); - let actors = mesh.spawn::("failing", &()).await.unwrap(); + let mut actors = mesh.spawn::("failing", &()).await.unwrap(); actors .cast( @@ -653,6 +704,8 @@ mod tests { ProcEvent::Crashed(0, reason) if reason.contains("failmonkey") ); + assert_matches!(actors.next().await, Some(_)); + stop(); assert_matches!( events.next().await.unwrap(), @@ -664,5 +717,6 @@ mod tests { ); assert!(events.next().await.is_none()); + assert!(actors.next().await.is_none()); } }