Skip to content

Commit 00a0ad1

Browse files
committed
[3/n tensor engine] hello tensor engine
Add the initial controller class `mesh_controller`, which implements the tensor engine on top of the ProcMesh/ProcActor API. The only example is currently a "hello world" that allocates and fetches a tensor. Follow-up PRs will integrate creating meshes this way more deeply into the testing code and fix the issues that come up with it. This design assumes that supervision is going to be handled by the actor system, and that tensor compute can simply rely on that for monitoring and stuckness detection. This design has no ClientActor, and the ControllerActor exists only as an Instance handle for reading messages from the workers (which send some controller messages). This does not attempt to clean up the existing RustController system yet, since the new controller isn't yet feature-equivalent to it or tested against it. Differential Revision: [D75909313](https://our.internmc.facebook.com/intern/diff/D75909313/) ghstack-source-id: 288027712 Pull Request resolved: #187
1 parent 108c737 commit 00a0ad1

File tree

21 files changed

+798
-25
lines changed

21 files changed

+798
-25
lines changed

controller/src/history.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ enum RefStatus {
115115
/// borrows, drops etc. directly.
116116
#[derive(Debug)]
117117
#[allow(dead_code)]
118-
pub(crate) struct History {
118+
pub struct History {
119119
/// The first incomplete Seq for each rank. This is used to determine which
120120
/// Seqs are no longer relevant and can be purged from the history.
121121
first_incomplete_seqs: MinVector<Seq>,
@@ -198,7 +198,7 @@ where
198198
}
199199

200200
impl History {
201-
pub(crate) fn new(world_size: usize) -> Self {
201+
pub fn new(world_size: usize) -> Self {
202202
Self {
203203
first_incomplete_seqs: MinVector::new(vec![Seq::default(); world_size]),
204204
min_incomplete_seq: Seq::default(),
@@ -213,23 +213,23 @@ impl History {
213213
}
214214

215215
#[cfg(test)]
216-
pub(crate) fn first_incomplete_seqs(&self) -> &[Seq] {
216+
pub fn first_incomplete_seqs(&self) -> &[Seq] {
217217
self.first_incomplete_seqs.vec()
218218
}
219219

220-
pub(crate) fn first_incomplete_seqs_controller(&self) -> &[Seq] {
220+
pub fn first_incomplete_seqs_controller(&self) -> &[Seq] {
221221
self.first_incomplete_seqs_controller.vec()
222222
}
223223

224-
pub(crate) fn min_incomplete_seq_reported(&self) -> Seq {
224+
pub fn min_incomplete_seq_reported(&self) -> Seq {
225225
self.min_incompleted_seq_controller
226226
}
227227

228-
pub(crate) fn world_size(&self) -> usize {
228+
pub fn world_size(&self) -> usize {
229229
self.first_incomplete_seqs.len()
230230
}
231231

232-
pub(crate) fn delete_invocations_for_refs(&mut self, refs: Vec<Ref>) {
232+
pub fn delete_invocations_for_refs(&mut self, refs: Vec<Ref>) {
233233
self.marked_for_deletion.extend(refs);
234234

235235
self.marked_for_deletion
@@ -251,7 +251,7 @@ impl History {
251251
}
252252

253253
/// Add an invocation to the history.
254-
pub(crate) fn add_invocation(
254+
pub fn add_invocation(
255255
&mut self,
256256
seq: Seq,
257257
uses: Vec<Ref>,
@@ -306,7 +306,7 @@ impl History {
306306

307307
/// Propagate worker error to the invocation with the given Seq. This will also propagate
308308
/// to all seqs that depend on this seq directly or indirectly.
309-
pub(crate) fn propagate_exception(&mut self, seq: Seq, exception: Exception) {
309+
pub fn propagate_exception(&mut self, seq: Seq, exception: Exception) {
310310
let mut queue = vec![seq];
311311
let mut visited = HashSet::new();
312312

@@ -364,13 +364,13 @@ impl History {
364364
results
365365
}
366366

367-
pub(crate) fn report_deadline_missed(&mut self) {
367+
pub fn report_deadline_missed(&mut self) {
368368
if let Some((seq, time, _)) = self.deadline {
369369
self.deadline = Some((seq, time, true));
370370
}
371371
}
372372

373-
pub(crate) fn deadline(
373+
pub fn deadline(
374374
&mut self,
375375
expected_progress: u64,
376376
timeout: tokio::time::Duration,
@@ -397,7 +397,7 @@ impl History {
397397
self.deadline
398398
}
399399

400-
pub(crate) fn update_deadline_tracking(&mut self, rank: usize, seq: Seq) {
400+
pub fn update_deadline_tracking(&mut self, rank: usize, seq: Seq) {
401401
// rank_completed also calls this so that we stay up to date with client request_status messages.
402402
// However, controller request_status messages may be ahead of the client as the client may retain invocations
403403
// past the time completed so we should take the max
@@ -411,7 +411,7 @@ impl History {
411411

412412
/// Mark the given rank as completed up to but excluding the given Seq. This will also purge history for
413413
/// any Seqs that are no longer relevant (completed on all ranks).
414-
pub(crate) fn rank_completed(
414+
pub fn rank_completed(
415415
&mut self,
416416
rank: usize,
417417
seq: Seq,

controller/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#![allow(unsafe_op_in_unsafe_fn)]
1313

1414
pub mod bootstrap;
15-
mod history;
15+
pub mod history;
1616

1717
use std::collections::HashMap;
1818
use std::collections::HashSet;

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use ndslice::Range;
3030
use ndslice::Selection;
3131
use ndslice::Shape;
3232
use ndslice::ShapeError;
33+
use ndslice::Slice;
3334
use serde::Deserialize;
3435
use serde::Serialize;
3536

@@ -147,6 +148,75 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
147148
pub(crate) fn open_port<M: Message>(&self) -> (PortHandle<M>, PortReceiver<M>) {
148149
self.proc_mesh.client().open_port()
149150
}
151+
152+
/// Cast an [`M`]-typed message to the ranks selected by `sel`
153+
/// in this ActorMesh.
154+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
155+
pub fn cast<M: RemoteMessage + Clone>(
156+
&self,
157+
selection: Selection,
158+
message: M,
159+
) -> Result<(), CastError>
160+
where
161+
A: RemoteHandles<Cast<M>> + RemoteHandles<IndexedErasedUnbound<Cast<M>>>,
162+
{
163+
let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
164+
"message_type" => M::typename(),
165+
"message_variant" => message.arm().unwrap_or_default(),
166+
));
167+
let message = Cast {
168+
rank: CastRank(usize::MAX),
169+
shape: self.shape().clone(),
170+
message,
171+
};
172+
let message = CastMessageEnvelope::new(
173+
self.proc_mesh.client().actor_id().clone(),
174+
DestinationPort::new::<A, Cast<M>>(self.name.clone()),
175+
message,
176+
None, // TODO: reducer typehash
177+
)?;
178+
179+
self.proc_mesh.comm_actor().send(
180+
self.proc_mesh.client(),
181+
CastMessage {
182+
dest: Uslice {
183+
slice: self.shape().slice().clone(),
184+
selection,
185+
},
186+
message,
187+
},
188+
)?;
189+
Ok(())
190+
}
191+
192+
/// Until the selection logic is more powerful, we need a way to
193+
/// replicate the send patterns that the worker actor mesh actually does.
194+
pub fn cast_slices<M: RemoteMessage + Clone>(
195+
&self,
196+
sel: Vec<Slice>,
197+
message: M,
198+
) -> Result<(), CastError>
199+
where
200+
A: RemoteHandles<Cast<M>> + RemoteHandles<IndexedErasedUnbound<Cast<M>>>,
201+
{
202+
let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
203+
"message_type" => M::typename(),
204+
"message_variant" => message.arm().unwrap_or_default(),
205+
));
206+
for ref slice in sel {
207+
for rank in slice.iter() {
208+
let cast = Cast {
209+
rank: CastRank(rank),
210+
shape: self.shape().clone(),
211+
message: message.clone(),
212+
};
213+
self.ranks[rank]
214+
.send(self.proc_mesh.client(), cast)
215+
.map_err(|err| CastError::MailboxSenderError(rank, err))?;
216+
}
217+
}
218+
Ok(())
219+
}
150220
}
151221

152222
#[async_trait]

hyperactor_mesh/src/proc_mesh/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,14 @@ impl ProcMesh {
371371
&self.client
372372
}
373373

374+
/// Returns a reference to the client [`Proc`] held by this mesh.
pub fn client_proc(&self) -> &Proc {
    &self.client_proc
}
377+
378+
/// Returns the [`ProcId`] of this mesh's client proc, i.e. the id of the
/// proc returned by [`ProcMesh::client_proc`].
pub fn proc_id(&self) -> &ProcId {
    self.client_proc().proc_id()
}
381+
374382
/// An event stream of proc events. Each ProcMesh can produce only one such
375383
/// stream, returning None after the first call.
376384
pub fn events(&mut self) -> Option<ProcEvents> {

monarch_extension/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "
1919
controller = { version = "0.0.0", path = "../controller" }
2020
hyperactor = { version = "0.0.0", path = "../hyperactor" }
2121
hyperactor_extension = { version = "0.0.0", path = "../hyperactor_extension" }
22+
hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
2223
hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" }
2324
monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" }
2425
monarch_messages = { version = "0.0.0", path = "../monarch_messages" }

monarch_extension/src/client.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,17 @@ use crate::controller::PyRanks;
5151
use crate::convert::convert;
5252

5353
#[pyclass(frozen, module = "monarch._rust_bindings.monarch_extension.client")]
54-
struct WorkerResponse {
54+
pub struct WorkerResponse {
5555
seq: Seq,
5656
result: Option<Result<Serialized, Exception>>,
5757
}
5858

59+
impl WorkerResponse {
60+
pub fn new(seq: Seq, result: Option<Result<Serialized, Exception>>) -> Self {
61+
Self { seq, result }
62+
}
63+
}
64+
5965
#[pymethods]
6066
impl WorkerResponse {
6167
#[staticmethod]
@@ -510,7 +516,7 @@ pub struct DebuggerMessage {
510516
impl DebuggerMessage {
511517
#[new]
512518
#[pyo3(signature = (*, debugger_actor_id, action))]
513-
fn new(debugger_actor_id: PyActorId, action: DebuggerAction) -> PyResult<Self> {
519+
pub fn new(debugger_actor_id: PyActorId, action: DebuggerAction) -> PyResult<Self> {
514520
Ok(Self {
515521
debugger_actor_id,
516522
action,

monarch_extension/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod client;
1212
mod controller;
1313
pub mod convert;
1414
mod debugger;
15+
mod mesh_controller;
1516
mod panic;
1617
mod simulator_client;
1718
mod tensor_worker;
@@ -150,5 +151,10 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
150151
"monarch_extension.panic",
151152
)?)?;
152153

154+
crate::mesh_controller::register_python_bindings(&get_or_add_new_module(
155+
module,
156+
"monarch_extension.mesh_controller",
157+
)?)?;
158+
153159
Ok(())
154160
}

0 commit comments

Comments
 (0)