diff --git a/controller/src/lib.rs b/controller/src/lib.rs index e47610e7..09ce6d28 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -612,7 +612,6 @@ mod tests { use hyperactor::RefClient; use hyperactor::channel; use hyperactor::channel::ChannelTransport; - use hyperactor::channel::sim; use hyperactor::channel::sim::SimAddr; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; @@ -631,6 +630,7 @@ mod tests { use hyperactor::reference::GangId; use hyperactor::reference::ProcId; use hyperactor::reference::WorldId; + use hyperactor::simnet; use hyperactor_mesh::comm::CommActorParams; use hyperactor_multiprocess::System; use hyperactor_multiprocess::proc_actor::ProcMessage; @@ -1549,6 +1549,15 @@ mod tests { // Start system actor. let system_addr = ChannelAddr::any(ChannelTransport::Unix); let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); + simnet::start( + ChannelAddr::any(ChannelTransport::Unix), + proxy_addr.clone(), + 1000, + ) + .unwrap(); + simnet::simnet_handle() + .unwrap() + .set_training_script_state(simnet::TrainingScriptState::Waiting); let system_sim_addr = ChannelAddr::Sim(SimAddr::new(system_addr, proxy_addr.clone()).unwrap()); @@ -1647,7 +1656,7 @@ mod tests { assert_eq!(result.0, Seq::default()); assert!(result.1.expect("result").is_err()); - let records = sim::records().await.unwrap(); + let records = simnet::simnet_handle().unwrap().close().await.unwrap(); eprintln!("{}", serde_json::to_string_pretty(&records).unwrap()); } #[tokio::test] diff --git a/hyperactor/src/channel/sim.rs b/hyperactor/src/channel/sim.rs index 96496d3c..f7c2d6a0 100644 --- a/hyperactor/src/channel/sim.rs +++ b/hyperactor/src/channel/sim.rs @@ -15,11 +15,12 @@ use std::marker::PhantomData; use std::sync::Arc; use dashmap::DashMap; +use rand::RngCore; use regex::Regex; use tokio::sync::Mutex; use super::*; -use crate::PortId; +use crate::ActorId; use crate::channel; use crate::clock::Clock; use crate::clock::RealClock; @@ -30,8 +31,8 @@ use crate::mailbox::MessageEnvelope; use crate::simnet; use crate::simnet::Dispatcher; use crate::simnet::Event; -use crate::simnet::ProxyMessage; -use crate::simnet::ScheduledEvent; +use crate::simnet::PerfettoFlow; +use crate::simnet::PerfettoTrace; use crate::simnet::SimNetConfig; use crate::simnet::SimNetEdge; use crate::simnet::SimNetError; @@ -41,7 +42,6 @@ lazy_static! { static ref SENDER: SimDispatcher = SimDispatcher::default(); } static SIM_LINK_BUF_SIZE: usize = 256; -static CLIENT_ADDRESS: &str = "unix!@client"; #[derive( Clone, @@ -154,42 +154,103 @@ impl fmt::Display for SimAddr { } } +fn parse_message(input: &str) -> String { + let open_brace_pos = match input.find('{') { + Some(pos) => pos, + None => return input.to_string(), + }; + + let first_part = input[..open_brace_pos].to_string(); + + let start_quote_pos = match input[open_brace_pos + 1..].find('"') { + Some(pos) => pos + open_brace_pos + 1, + None => return input.to_string(), + }; + + let end_quote_pos = match input[start_quote_pos + 1..].find('"') { + Some(pos) => pos + start_quote_pos + 1, + None => return input.to_string(), + }; + + let second_part = input[start_quote_pos + 1..end_quote_pos].to_string(); + + format!("{}::{}", first_part, second_part) +} + /// Message Event that can be passed around in the simnet. #[derive(Debug)] -pub(crate) struct MessageDeliveryEvent { +pub(crate) struct MessageSendEvent { src_addr: Option, dest_addr: AddressProxyPair, data: Serialized, duration_ms: u64, + inflight_time_ms: u64, + // Used to match a pair of send/recv events. + id: String, + sender_actor: ActorId, + dest_actor: ActorId, + message_type: String, } -impl MessageDeliveryEvent { - /// Creates a new MessageDeliveryEvent. +impl MessageSendEvent { + /// Creates a new MessageSendEvent. pub fn new( src_addr: Option, dest_addr: AddressProxyPair, data: Serialized, ) -> Self { + let (sender_actor, dest_actor, message_type) = + if let Ok(envelope) = data.deserialized::() { + let msg_string = envelope.data().to_string(); + let parsed = parse_message(&msg_string); + ( + envelope.sender().clone(), + envelope.dest().actor_id().clone(), + parsed, + ) + } else { + ( + id!(unknown[0].unknown), + id!(unknown[0].unknown), + "unknown".to_string(), + ) + }; + Self { src_addr, dest_addr, data, - duration_ms: 100, + duration_ms: 1, + inflight_time_ms: 1, + id: format!("0x{}", random_hex_str()), + sender_actor, + dest_actor, + message_type, } } } #[async_trait] -impl Event for MessageDeliveryEvent { +impl Event for MessageSendEvent { async fn handle(&self) -> Result<(), SimNetError> { - // Send the message to the correct receiver. - SENDER - .send( - self.src_addr.clone(), - self.dest_addr.clone(), - self.data.clone(), - ) - .await?; + let inflight_time_ms = self.inflight_time_ms; + let event = Box::new(MessageRecvEvent::new( + self.src_addr.clone(), + self.dest_addr.clone(), + self.data.clone(), + Some(self.id.clone()), + )); + + tokio::task::spawn(async move { + SimClock + .sleep(tokio::time::Duration::from_millis(inflight_time_ms)) + .await; + + if let Ok(handle) = simnet_handle() { + let _ = handle.send_event(event); + } + }); + Ok(()) } @@ -197,23 +258,13 @@ impl Event for MessageDeliveryEvent { self.duration_ms } - fn summary(&self) -> String { - format!( - "Sending message from {} to {}", - self.src_addr - .as_ref() - .map_or("unknown".to_string(), |addr| addr.address.to_string()), - self.dest_addr.address.clone() - ) - } - async fn read_simnet_config(&mut self, topology: &Arc>) { if let Some(src_addr) = &self.src_addr { let edge = SimNetEdge { src: src_addr.address.clone(), dst: self.dest_addr.address.clone(), }; - self.duration_ms = topology + self.inflight_time_ms = topology .lock() .await .topology @@ -221,11 +272,112 @@ impl Event for MessageDeliveryEvent { .map_or_else(|| 1, |v| v.latency.as_millis() as u64); } } + + fn to_perfetto(&self, start: u64, _end: u64) -> Option { + Some(PerfettoTrace { + name: format!("send {}", self.message_type), + cat: "message".to_string(), + ph: "X".to_string(), + ts: start * 1000, + dur: 1000, + actor_id: self.sender_actor.clone(), + bind_id: Some(self.id.clone()), + flow: Some(PerfettoFlow::Out), + }) + } +} + +/// Message Recv Event that can be passed around in the simnet. +#[derive(Debug)] +pub(crate) struct MessageRecvEvent { + src_addr: Option, + dest_addr: AddressProxyPair, + data: Serialized, + duration_ms: u64, + // Used to match a pair of send/recv events. + id: Option, + sender_actor: ActorId, + dest_actor: ActorId, + message_type: String, +} + +impl MessageRecvEvent { + /// Creates a new MessageRecvEvent. + pub fn new( + src_addr: Option, + dest_addr: AddressProxyPair, + data: Serialized, + id: Option, + ) -> Self { + let (sender_actor, dest_actor, message_type) = + if let Ok(envelope) = data.deserialized::() { + let msg_string = envelope.data().to_string(); + let parsed = parse_message(&msg_string); + ( + envelope.sender().clone(), + envelope.dest().actor_id().clone(), + parsed, + ) + } else { + ( + hyperactor::id!(unknown[0].unknown), + hyperactor::id!(unknown[0].unknown), + "unknown".to_string(), + ) + }; + + Self { + src_addr, + dest_addr, + data, + duration_ms: 1, + id, + sender_actor, + dest_actor, + message_type, + } + } +} + +#[async_trait] +impl Event for MessageRecvEvent { + async fn handle(&self) -> Result<(), SimNetError> { + // Send the message to the correct receiver. + SENDER + .send( + self.src_addr.clone(), + self.dest_addr.clone(), + self.data.clone(), + ) + .await?; + Ok(()) + } + + fn duration_ms(&self) -> u64 { + self.duration_ms + } + + fn to_perfetto(&self, start: u64, _end: u64) -> Option { + Some(PerfettoTrace { + name: format!("recv {}", self.message_type), + cat: "message".to_string(), + ph: "X".to_string(), + ts: start * 1000, + dur: 1000, + actor_id: self.dest_actor.clone(), + bind_id: self.id.clone(), + flow: Some(PerfettoFlow::In), + }) + } } -/// Export the message delivery records of the simnet. -pub async fn records() -> anyhow::Result>, SimNetError> { - Ok(simnet_handle()?.records().await) +fn random_hex_str() -> String { + let mut bytes = vec![0u8; 10 / 2]; + rand::thread_rng().fill_bytes(&mut bytes); + bytes + .into_iter() + .map(|b| format!("{:02x}", b)) + .collect::() } /// Bind a channel address to the simnet. It will register the address as a node in simnet, @@ -329,31 +481,10 @@ fn is_external_addr(addr: &AddressProxyPair) -> anyhow::Result for SimDispatcher { async fn send( &self, - src_addr: Option, + _src_addr: Option, addr: AddressProxyPair, data: Serialized, ) -> Result<(), SimNetError> { - if is_external_addr(&addr)? { - let dst_proxy = addr.proxy.clone(); - let sender = self - .sender_cache - .entry(dst_proxy.clone()) - .or_insert_with(|| create_egress_sender(dst_proxy.clone()).unwrap()); - let forward_message = ProxyMessage::new(src_addr.clone(), Some(addr.clone()), data); - let serialized_forward_message = match Serialized::serialize(&forward_message) { - Ok(data) => data, - Err(err) => return Err(SimNetError::InvalidArg(err.to_string())), - }; - // Here we use mailbox to deliver the ForwardMessage. But it's higher level than - // the simnet. So there are unused placeholder here which is not ideal. - let port_id_placeholder = PortId(id!(unused_world[0].unused_actor), 0); - let message = - MessageEnvelope::new_unknown(port_id_placeholder, serialized_forward_message); - return sender - .try_post(message, oneshot::channel().0) - .map_err(|err| SimNetError::InvalidNode(addr.address.to_string(), err.into())); - } - self.dispatchers .get(&addr.address) .ok_or_else(|| { @@ -402,16 +533,25 @@ impl Tx for SimTx { }; match simnet_handle() { Ok(handle) => match &self.src_addr { - Some(src_addr) if src_addr.address.to_string() == CLIENT_ADDRESS => handle - .send_scheduled_event(ScheduledEvent { - event: Box::new(MessageDeliveryEvent::new( - self.src_addr.clone(), - self.dst_addr.clone(), - data, - )), - time: SimClock.millis_since_start(RealClock.now()), - }), - _ => handle.send_event(Box::new(MessageDeliveryEvent::new( + Some(src_addr) if src_addr.proxy != *handle.proxy_addr() => { + let event = Box::new(MessageRecvEvent::new( + self.src_addr.clone(), + self.dst_addr.clone(), + data, + None, + )); + let recv_time = RealClock.now(); + + tokio::task::spawn(async move { + SimClock.sleep_until(recv_time).await; + + if let Ok(handle) = simnet_handle() { + let _ = handle.send_event(event); + } + }); + Ok(()) + } + _ => handle.send_event(Box::new(MessageSendEvent::new( self.src_addr.clone(), self.dst_addr.clone(), data, @@ -493,10 +633,13 @@ mod tests { use std::iter::zip; use super::*; + use crate::PortId; use crate::clock::Clock; use crate::clock::RealClock; use crate::clock::SimClock; + use crate::id; use crate::simnet::NetworkConfig; + use crate::simnet::ProxyMessage; use crate::simnet::start; #[tokio::test] @@ -539,84 +682,10 @@ mod tests { assert_eq!(rx.recv().await.unwrap(), 123); } - let records = sim::records().await; + let records = sim::simnet_handle().unwrap().close().await.unwrap(); eprintln!("records: {:#?}", records); } - #[tokio::test] - async fn test_send_egress_message() { - let proxy = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); - - // Serve an external proxy channel to receive the egress message. - let egress_addr = ChannelAddr::any(ChannelTransport::Unix); - let dispatcher = SimDispatcher::default(); - let (_, mut rx) = channel::serve::(egress_addr.clone()) - .await - .unwrap(); - // just a random port ID - let port_id = PortId(id!(test[0].actor0), 0); - let msg = MessageEnvelope::new_unknown( - port_id.clone(), - Serialized::serialize(&"hola".to_string()).unwrap(), - ); - // The sim addr we want simnet to send message to, it should have the egress_addr - // as the proxy address of dst. - let src_addr = AddressProxyPair { - address: "unix!@src".parse::().unwrap(), - proxy: "unix!@proxy".parse::().unwrap(), - }; - let egress_addr = AddressProxyPair { - address: "unix!@dst".parse::().unwrap(), - proxy: egress_addr, - }; - let serialized_msg = Serialized::serialize(&msg).unwrap(); - dispatcher - .send( - Some(src_addr.clone()), - egress_addr.clone(), - serialized_msg.clone(), - ) - .await - .unwrap(); - let received_msg = rx.recv().await.unwrap(); - let actual_forward_msg: ProxyMessage = received_msg.deserialized().unwrap(); - let expected_forward_msg = ProxyMessage::new( - Some(src_addr.clone()), - Some(egress_addr.clone()), - serialized_msg, - ); - - assert_eq!(actual_forward_msg, expected_forward_msg); - - // Sending the message again should work by using the cached sender. - // But it's impl detail, not verified here. We just verify that it - // can send a different message. - let msg = MessageEnvelope::new_unknown( - port_id, - Serialized::serialize(&"ciao".to_string()).unwrap(), - ); - let serialized_msg = Serialized::serialize(&msg).unwrap(); - dispatcher - .send( - Some(src_addr.clone()), - egress_addr.clone(), - serialized_msg.clone(), - ) - .await - .unwrap(); - let received_msg = rx.recv().await.unwrap(); - let actual_forward_msg: ProxyMessage = received_msg.deserialized().unwrap(); - let expected_forward_msg = - ProxyMessage::new(Some(src_addr), Some(egress_addr), serialized_msg); - assert_eq!(actual_forward_msg, expected_forward_msg); - } - #[tokio::test] async fn test_invalid_sim_addr() { let src = "sim!src"; @@ -698,16 +767,23 @@ mod tests { // This message will be delievered at simulator time = 100 seconds tx.try_post((), oneshot::channel().0).unwrap(); { - // Allow some time for simnet to run - RealClock.sleep(tokio::time::Duration::from_secs(1)).await; + // Allow simnet to run + tokio::task::yield_now().await; // Messages have not been receive since 10 seconds have not elapsed assert!(rx.rx.try_recv().is_err()); } - // Advance "real" time by 100 seconds - tokio::time::advance(tokio::time::Duration::from_secs(100)).await; + tokio::time::advance( + // Advance "real" time by 1 ms for send time + tokio::time::Duration::from_millis(1) + // Advance "real" time by 100 seconds for inflight time + + tokio::time::Duration::from_secs(100) + // Advance "real" time by 1 ms for recv time + + tokio::time::Duration::from_millis(1), + ) + .await; { // Allow some time for simnet to run - RealClock.sleep(tokio::time::Duration::from_secs(1)).await; + tokio::task::yield_now().await; // Messages are received assert!(rx.rx.try_recv().is_ok()); } @@ -723,21 +799,11 @@ mod tests { 1000, ) .unwrap(); - let controller_to_dst = SimAddr::new_with_src( - AddressProxyPair { - address: "unix!@controller".parse::().unwrap(), - proxy: proxy_addr.clone(), - }, - "unix!@dst".parse::().unwrap(), - proxy_addr.clone(), - ) - .unwrap(); - let controller_tx = sim::dial::<()>(controller_to_dst.clone()).unwrap(); let client_to_dst = SimAddr::new_with_src( AddressProxyPair { - address: "unix!@client".parse::().unwrap(), - proxy: proxy_addr.clone(), + address: ChannelAddr::any(ChannelTransport::Unix), + proxy: ChannelAddr::any(ChannelTransport::Unix), }, "unix!@dst".parse::().unwrap(), proxy_addr.clone(), @@ -745,35 +811,27 @@ mod tests { .unwrap(); let client_tx = sim::dial::<()>(client_to_dst).unwrap(); - // 1 second of latency - let simnet_config_yaml = r#" - edges: - - src: unix!@controller - dst: unix!@dst - metadata: - latency: 1 - "#; - update_config(NetworkConfig::from_yaml(simnet_config_yaml).unwrap()) - .await - .unwrap(); - assert_eq!(SimClock.millis_since_start(RealClock.now()), 0); // Fast forward real time to 5 seconds tokio::time::advance(tokio::time::Duration::from_secs(5)).await; { // Send client message client_tx.try_post((), oneshot::channel().0).unwrap(); - // Send system message - controller_tx.try_post((), oneshot::channel().0).unwrap(); + tokio::time::advance(tokio::time::Duration::from_millis(1)).await; // Allow some time for simnet to run - RealClock.sleep(tokio::time::Duration::from_secs(1)).await; + tokio::task::yield_now().await; } - let recs = records().await.unwrap().unwrap(); - assert_eq!(recs.len(), 2); - let end_times = recs.iter().map(|rec| rec.end_at).collect::>(); + let recs = simnet::simnet_handle().unwrap().close().await.unwrap(); + let recs = recs + .as_array() + .unwrap() + .iter() + .filter(|r| r["ph"] != "M") + .collect::>(); // client message was delivered at "real" time = 5 seconds - assert!(end_times.contains(&5000)); - // system message was delivered at simulated time = 1 second - assert!(end_times.contains(&1000)); + assert_eq!( + recs.first().map(|rec| rec["ts"].as_u64().unwrap()).unwrap(), + 5000 * 1000 + ); } } diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index d2f668fb..3a39b9ff 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -13,6 +13,7 @@ //! testing and development of message distribution techniques. use std::collections::BTreeMap; +use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; use std::sync::Arc; @@ -48,12 +49,13 @@ use crate::ActorId; use crate::Mailbox; use crate::Named; use crate::OncePortRef; +use crate::ProcId; use crate::WorldId; use crate::channel; use crate::channel::ChannelAddr; use crate::channel::Rx; use crate::channel::sim::AddressProxyPair; -use crate::channel::sim::MessageDeliveryEvent; +use crate::channel::sim::MessageRecvEvent; use crate::clock::Clock; use crate::clock::RealClock; use crate::clock::SimClock; @@ -86,7 +88,7 @@ type SimulatorTimeInstant = u64; /// The unit of execution for the simulator. /// Using handle(), simnet can schedule executions in the network. /// If you want to send a message for example, you would want to implement -/// a MessageDeliveryEvent much on the lines expressed in simnet tests. +/// a MessageSendEvent much on the lines expressed in simnet tests. /// You can also do other more advanced concepts such as node churn, /// or even simulate process spawns in a distributed system. For example, /// one can implement a SystemActorSimEvent in order to spawn a system @@ -115,8 +117,10 @@ pub trait Event: Send + Sync + Debug { /// Read the simnet config and update self accordingly. async fn read_simnet_config(&mut self, _topology: &Arc>) {} - /// A user-friendly summary of the event - fn summary(&self) -> String; + /// The event as a Perfetto trace. + fn to_perfetto(&self, _start: u64, _end: u64) -> Option { + None + } } /// This is a simple event that is used to join a node to the network. @@ -140,10 +144,6 @@ impl Event for NodeJoinEvent { fn duration_ms(&self) -> u64 { 0 } - - fn summary(&self) -> String { - "Node join".into() - } } #[derive(Debug)] @@ -180,10 +180,6 @@ impl Event for SleepEvent { fn duration_ms(&self) -> u64 { self.duration_ms } - - fn summary(&self) -> String { - format!("Sleeping for {} ms", self.duration_ms) - } } #[derive(Debug)] @@ -212,19 +208,20 @@ impl Event for TorchOpEvent { } fn duration_ms(&self) -> u64 { - 100 + 2 } - fn summary(&self) -> String { - let kwargs_string = if self.kwargs_string.is_empty() { - "".to_string() - } else { - format!(", {}", self.kwargs_string) - }; - format!( - "[{}] Torch Op: {}({}{})", - self.worker_actor_id, self.op, self.args_string, kwargs_string - ) + fn to_perfetto(&self, start: u64, end: u64) -> Option { + Some(PerfettoTrace { + name: self.op.clone(), + cat: "compute".to_string(), + ph: "X".to_string(), + ts: start * 1000, + dur: (end - start) * 1000, + actor_id: self.worker_actor_id.clone(), + bind_id: None, + flow: None, + }) } } @@ -353,10 +350,9 @@ pub enum TrainingScriptState { /// A handle to a running [`SimNet`] instance. pub struct SimNetHandle { - join_handles: Arc>>>, + join_handle: Mutex>>>, event_tx: UnboundedSender<(Box, bool, Option)>, config: Arc>, - records: Option>>>, pending_event_count: Arc, /// Handle to a running proxy server that forwards external messages /// into the simnet. @@ -396,19 +392,6 @@ impl SimNetHandle { self.send_event_impl(event, false) } - /// Sends an event that already has a scheduled time onto the simnet's event loop - #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. - pub(crate) fn send_scheduled_event( - &self, - ScheduledEvent { event, time }: ScheduledEvent, - ) -> Result<(), SimNetError> { - self.pending_event_count - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); - self.event_tx - .send((event, true, Some(time))) - .map_err(|err| SimNetError::Closed(err.to_string())) - } - /// Let the simnet know if the training script is running or waiting for the backend /// to return a future result. pub fn set_training_script_state(&self, state: TrainingScriptState) { @@ -433,25 +416,20 @@ impl SimNetHandle { /// Close the simulator, processing pending messages before /// completing the returned future. - pub async fn close(&self) -> Result<(), JoinError> { + pub async fn close(&self) -> Result { // Stop the proxy if there is one. self.proxy_handle.stop().await?; // Signal the simnet loop to stop self.stop_signal.store(true, Ordering::SeqCst); - let mut handles = self.join_handles.lock().await; - for handle in handles.drain(..) { - handle.await?; - } - Ok(()) - } - /// Get a copy of records of message deliveries in the simulation. - pub async fn records(&self) -> Option> { - if let Some(records) = &self.records { - Some(records.lock().await.clone()) + let mut guard = self.join_handle.lock().await; + let records = if let Some(handle) = guard.take() { + handle.await } else { - None - } + Ok(vec![]) + }?; + + Ok(serde_json::Value::Array(records)) } /// Update the network configuration to SimNet. @@ -561,6 +539,62 @@ impl SimOperation { } } +/// Represents the direction of a flow in a Perfetto trace +pub enum PerfettoFlow { + /// Indicates an incoming flow to a trace event + /// Adds the field `"flow_in": true` when converting to JSON + In, + /// Indicates an outgoing flow from a trace event + /// Adds the field `"flow_out": true` when converting to JSON + Out, +} + +/// Represents a trace event in the Perfetto tracing format +pub struct PerfettoTrace { + /// The name of the trace event + pub name: String, + /// The category of the trace event + pub cat: String, + /// The phase of the trace event (e.g., "X" for complete events) + pub ph: String, + /// The timestamp of the trace event in microseconds + pub ts: u64, + /// The duration of the trace event in microseconds + pub dur: u64, + /// The actor ID associated with this trace event + /// When converting to JSON this will be resolved to the appropriate + /// `pid` and `tid` + pub actor_id: ActorId, + /// Optional binding ID for connecting related trace events + pub bind_id: Option, + /// Optional flow direction for flow events + pub flow: Option, +} + +impl PerfettoTrace { + fn to_json(&self, pid: usize, tid: usize) -> serde_json::Value { + let mut json = serde_json::json!({ + "name": self.name, + "cat": self.cat, + "ph": self.ph, + "ts": self.ts, + "dur": self.dur, + "pid": pid, + "tid": tid, + }); + if let Some(flow) = &self.flow { + match flow { + PerfettoFlow::In => json["flow_in"] = serde_json::Value::Bool(true), + PerfettoFlow::Out => json["flow_out"] = serde_json::Value::Bool(true), + } + } + if let Some(bind_id) = &self.bind_id { + json["bind_id"] = serde_json::Value::String(bind_id.to_string()); + } + json + } +} + #[async_trait] impl Event for SimOperation { async fn handle(&self) -> Result<(), SimNetError> { @@ -572,10 +606,6 @@ impl Event for SimOperation { fn duration_ms(&self) -> u64 { 0 } - - fn summary(&self) -> String { - format!("SimOperation: {:?}", self.operational_message) - } } /// A ProxyMessage is a message that SimNet proxy receives. @@ -653,9 +683,11 @@ pub struct SimNet { address_book: DashSet, state: State, max_latency: Duration, - records: Option>>>, + records: Vec, // number of events that has been received but not yet processed. pending_event_count: Arc, + pids: HashMap, + tids: HashMap, } /// A proxy to bridge external nodes and the SimNet. @@ -691,10 +723,11 @@ impl ProxyHandle { if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await { let proxy_message: ProxyMessage = msg.deserialized().unwrap(); let event: Box = match proxy_message.dest_addr { - Some(dest_addr) => Box::new(MessageDeliveryEvent::new( + Some(dest_addr) => Box::new(MessageRecvEvent::new( proxy_message.sender_addr, dest_addr, proxy_message.data, + None, )), None => { let operational_message: OperationalMessage = @@ -769,15 +802,15 @@ pub fn start( tokio::sync::watch::channel(TrainingScriptState::Running); let (event_tx, event_rx) = mpsc::unbounded_channel::<(Box, bool, Option)>(); - let records = Some(Arc::new(Mutex::new(vec![]))); // TODO remove optional let pending_event_count = Arc::new(AtomicUsize::new(0)); let stop_signal = Arc::new(AtomicBool::new(false)); - let simnet_join_handle = { - let records = records.clone(); + + let join_handle = Mutex::new(Some({ let config = config.clone(); let pending_event_count = pending_event_count.clone(); let stop_signal = stop_signal.clone(); - tokio::task::spawn_blocking(move || { + + tokio::spawn(async move { let mut net = SimNet { config, address_book, @@ -786,13 +819,15 @@ pub fn start( unadvanceable_scheduled_events: BTreeMap::new(), }, max_latency: Duration::from_millis(max_duration_ms), - records, + records: Vec::new(), pending_event_count, + pids: HashMap::new(), + tids: HashMap::new(), }; - block_on(net.run(event_rx, training_script_state_rx, stop_signal)); + net.run(event_rx, training_script_state_rx, stop_signal) + .await }) - }; - let join_handles = Arc::new(Mutex::new(vec![simnet_join_handle])); + })); let (operational_message_tx, operational_message_rx) = mpsc::unbounded_channel::(); @@ -805,10 +840,9 @@ pub fn start( .map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?; HANDLE.get_or_init(|| SimNetHandle { - join_handles, + join_handle, event_tx, config, - records, pending_event_count, proxy_handle, operational_message_tx, @@ -867,14 +901,28 @@ impl SimNet { } /// Schedule the event into the network. - async fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) { - if let Some(records) = &self.records { - records.lock().await.push(SimulatorEventRecord { - summary: scheduled_event.event.summary(), - start_at: SimClock.millis_since_start(SimClock.now()), - end_at: scheduled_event.time, - }); + fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) { + if let Some(trace) = scheduled_event.event.to_perfetto( + SimClock.millis_since_start(SimClock.now()), + scheduled_event.time, + ) { + let (next_pid, next_tid) = (self.pids.len(), self.tids.len()); + + let pid = self + .pids + .entry(trace.actor_id.proc_id().clone()) + .or_insert_with(|| next_pid) + .clone(); + + let tid = self + .tids + .entry(trace.actor_id.clone()) + .or_insert_with(|| next_tid) + .clone(); + + self.records.push(trace.to_json(pid, tid)); } + if advanceable { self.state .scheduled_events @@ -890,6 +938,32 @@ impl SimNet { } } + fn make_metadata_traces(&self) -> Vec { + let mut metadata_traces = vec![]; + for (proc_id, pid) in self.pids.iter() { + metadata_traces.push(serde_json::json!({ + "ph": "M", + "name": "process_name", + "pid": pid, + "args": { + "name": proc_id.to_string(), + } + })) + } + for (actor_id, tid) in self.tids.iter() { + metadata_traces.push(serde_json::json!({ + "ph": "M", + "name": "thread_name", + "pid": self.pids.get(actor_id.proc_id()).unwrap_or(&0), + "tid": tid, + "args": { + "name": actor_id.to_string(), + } + })) + } + metadata_traces + } + /// Run the simulation. This will dispatch all the messages in the network. /// And wait for new ones. async fn run( @@ -897,7 +971,7 @@ impl SimNet { mut event_rx: UnboundedReceiver<(Box, bool, Option)>, training_script_state_rx: tokio::sync::watch::Receiver, stop_signal: Arc, - ) { + ) -> Vec { // The simulated number of milliseconds the training script // has spent waiting for the backend to resolve a future let mut training_script_waiting_time: u64 = 0; @@ -906,7 +980,8 @@ impl SimNet { 'outer: loop { // Check if we should stop if stop_signal.load(Ordering::SeqCst) { - break 'outer; + let metadata = self.make_metadata_traces(); + break 'outer self.records.drain(..).chain(metadata).collect::>(); } while let Ok((event, advanceable, time)) = event_rx.try_recv() { @@ -917,7 +992,7 @@ impl SimNet { }, None => self.create_scheduled_event(event).await, }; - self.schedule_event(scheduled_event, advanceable).await; + self.schedule_event(scheduled_event, advanceable); } { @@ -936,6 +1011,7 @@ impl SimNet { + training_script_waiting_time }) { + tokio::task::yield_now().await; continue; } match ( @@ -990,7 +1066,7 @@ impl SimNet { }, None => self.create_scheduled_event(event).await, }; - self.schedule_event(scheduled_event, advanceable).await; + self.schedule_event(scheduled_event, advanceable); }, _ = RealClock.sleep(Duration::from_millis(10)) => {} } @@ -1006,7 +1082,8 @@ impl SimNet { self.pending_event_count .fetch_sub(1, std::sync::atomic::Ordering::SeqCst); if scheduled_event.event.handle_network(self).await.is_err() { - break 'outer; + let metadata = self.make_metadata_traces(); + break 'outer self.records.drain(..).chain(metadata).collect::>(); } } } @@ -1035,17 +1112,6 @@ where s.parse().map_err(serde::de::Error::custom) } -/// DeliveryRecord is a structure to bookkeep the message events. -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct SimulatorEventRecord { - /// Event dependent summary for user - pub summary: String, - /// The time at which the message delivery was started. - pub start_at: SimulatorTimeInstant, - /// The time at which the message was delivered to the receiver. - pub end_at: SimulatorTimeInstant, -} - /// A configuration for the network topology. #[derive(Debug, Serialize, Deserialize)] pub struct NetworkConfig { @@ -1094,55 +1160,74 @@ mod tests { use crate::simnet::SimNetError; #[derive(Debug)] - struct MessageDeliveryEvent { + struct MessageSendEvent { src_addr: SimAddr, dest_addr: SimAddr, data: Serialized, duration_ms: u64, dispatcher: Option, + inflight_time_ms: u64, } #[async_trait] - impl Event for MessageDeliveryEvent { + impl Event for MessageSendEvent { async fn handle(&self) -> Result<(), simnet::SimNetError> { - if let Some(dispatcher) = &self.dispatcher { - dispatcher - .send( - Some(self.src_addr.clone()), - self.dest_addr.clone(), - self.data.clone(), - ) - .await?; - } + let inflight_time_ms = self.inflight_time_ms; + let event = Box::new(MessageRecvEvent { + src_addr: self.src_addr.clone(), + dest_addr: self.dest_addr.clone(), + data: self.data.clone(), + duration_ms: 0, + dispatcher: self.dispatcher.clone(), + }); + + tokio::task::spawn(async move { + SimClock + .sleep(tokio::time::Duration::from_millis(inflight_time_ms)) + .await; + + if let Ok(handle) = simnet_handle() { + let _ = handle.send_event(event); + } + }); Ok(()) } fn duration_ms(&self) -> u64 { self.duration_ms } - fn summary(&self) -> String { - format!( - "Sending message from {} to {}", - self.src_addr.addr().clone(), - self.dest_addr.addr().clone() - ) - } - async fn read_simnet_config(&mut self, config: &Arc>) { let edge = SimNetEdge { src: self.src_addr.addr().clone(), dst: self.dest_addr.addr().clone(), }; - self.duration_ms = config + self.inflight_time_ms = config .lock() .await .topology .get(&edge) .map_or_else(|| 1, |v| v.latency.as_millis() as u64); } + + fn to_perfetto(&self, start: u64, _end: u64) -> Option { + Some(PerfettoTrace { + name: format!( + "{} sending message to {}", + self.src_addr.addr(), + self.dest_addr.addr(), + ), + cat: "message".to_string(), + ph: "X".to_string(), + ts: start * 1000, + dur: self.duration_ms, + actor_id: id!(unknown[0].unknown), + bind_id: None, + flow: None, + }) + } } - impl MessageDeliveryEvent { + impl MessageSendEvent { fn new( src_addr: SimAddr, dest_addr: SimAddr, @@ -1153,9 +1238,54 @@ mod tests { src_addr, dest_addr, data, - duration_ms: 1, + duration_ms: 0, dispatcher, + inflight_time_ms: 1, + } + } + } + + #[derive(Debug)] + struct MessageRecvEvent { + src_addr: SimAddr, + dest_addr: SimAddr, + data: Serialized, + duration_ms: u64, + dispatcher: Option, + } + + #[async_trait] + impl Event for MessageRecvEvent { + async fn handle(&self) -> Result<(), simnet::SimNetError> { + if let Some(dispatcher) = &self.dispatcher { + dispatcher + .send( + Some(self.src_addr.clone()), + self.dest_addr.clone(), + self.data.clone(), + ) + .await?; } + Ok(()) + } + fn duration_ms(&self) -> u64 { + self.duration_ms + } + fn to_perfetto(&self, start: u64, end: u64) -> Option { + Some(PerfettoTrace { + name: format!( + "{} received message from {}", + self.dest_addr.addr(), + self.src_addr.addr(), + ), + cat: "message".to_string(), + ph: "X".to_string(), + ts: start * 1000, + dur: self.duration_ms, + actor_id: id!(unknown[0].unknown), + bind_id: None, + flow: None, + }) } } @@ -1244,7 +1374,7 @@ mod tests { let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); let alice = SimAddr::new(alice, proxy_addr.clone()).unwrap(); let bob = SimAddr::new(bob, proxy_addr.clone()).unwrap(); - let msg = Box::new(MessageDeliveryEvent::new( + let msg = Box::new(MessageSendEvent::new( alice, bob, Serialized::serialize(&"123".to_string()).unwrap(), @@ -1256,14 +1386,24 @@ mod tests { .flush(Duration::from_secs(30)) .await .unwrap(); - let records = simnet_handle().unwrap().records().await; - let expected_record = SimulatorEventRecord { - summary: "Sending message from local!1 to local!2".to_string(), - start_at: 0, - end_at: latency.as_millis() as u64, - }; - assert!(records.as_ref().unwrap().len() == 1); - assert_eq!(records.unwrap().first().unwrap(), &expected_record); + let records = simnet_handle().unwrap().close().await.unwrap(); + let records = records + .as_array() + .unwrap() + .iter() + .filter(|r| r["ph"] != "M") + .collect::>(); + let expected_record = serde_json::json!({ + "cat": "message", + "dur": 0, + "name": "local!2 received message from local!1".to_string(), + "ph": "X", + "pid": 0, + "tid": 0, + "ts": 1000 * 1000, + }); + assert!(records.len() == 2, "{:?}", records); + assert_eq!(*records.last().unwrap(), &expected_record); } #[tokio::test] @@ -1300,7 +1440,7 @@ mod tests { for _ in 0..10 { simnet_handle() .unwrap() - .send_event(Box::new(MessageDeliveryEvent::new( + .send_event(Box::new(MessageSendEvent::new( alice.clone(), bob.clone(), Serialized::serialize(&"123".to_string()).unwrap(), @@ -1316,14 +1456,21 @@ mod tests { .await .unwrap(); - let records = simnet_handle().unwrap().records().await; - assert_eq!(records.as_ref().unwrap().len(), 10); + let records = simnet_handle().unwrap().close().await.unwrap(); + let records = records + .as_array() + .unwrap() + .iter() + .filter(|r| r["ph"] != "M" && r["name"].as_str().unwrap().contains("received")) + .collect::>(); + assert_eq!(records.len(), 10); // If debounce is successful, the simnet will not advance to the delivery of any of // the messages before all are received + let last_record = records.last().unwrap(); assert_eq!( - records.unwrap().last().unwrap().end_at, - latency.as_millis() as u64 + last_record["ts"].as_u64().unwrap() + last_record["dur"].as_u64().unwrap(), + latency.as_micros() as u64, ); } @@ -1357,19 +1504,19 @@ mod tests { let addr_1 = SimAddr::new(addresses[1].clone(), proxy_addr.clone()).unwrap(); let addr_2 = SimAddr::new(addresses[2].clone(), proxy_addr.clone()).unwrap(); let addr_3 = SimAddr::new(addresses[3].clone(), proxy_addr.clone()).unwrap(); - let one = Box::new(MessageDeliveryEvent::new( + let one = Box::new(MessageSendEvent::new( addr_0.clone(), addr_1.clone(), messages[0].clone(), sender.clone(), )); - let two = Box::new(MessageDeliveryEvent::new( + let two = Box::new(MessageSendEvent::new( addr_2.clone(), addr_3.clone(), messages[1].clone(), sender.clone(), )); - let three = Box::new(MessageDeliveryEvent::new( + let three = Box::new(MessageSendEvent::new( addr_0.clone(), addr_1.clone(), messages[2].clone(), @@ -1385,7 +1532,7 @@ mod tests { .flush(Duration::from_millis(1000)) .await .unwrap(); - let records = simnet_handle().unwrap().records().await; + let records = simnet_handle().unwrap().close().await.unwrap(); eprintln!("Records: {:?}", records); // Close the channel simnet_handle().unwrap().close().await.unwrap(); @@ -1456,7 +1603,6 @@ edges: use crate::PortId; use crate::channel::Tx; - use crate::channel::sim::records; let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); start( @@ -1490,14 +1636,25 @@ edges: .flush(Duration::from_millis(1000)) .await .unwrap(); - let records = records().await.unwrap(); - assert!(records.as_ref().unwrap().len() == 1); - let expected_record = SimulatorEventRecord { - summary: format!("Sending message from {} to {}", src, dst), - start_at: 0, - end_at: 1, - }; - assert_eq!(records.unwrap().first().unwrap(), &expected_record); + let records = simnet_handle().unwrap().close().await.unwrap(); + let records = records + .as_array() + .unwrap() + .iter() + .filter(|r| r["ph"] != "M") + .collect::>(); + assert!(records.len() == 1); + let expected_record = serde_json::json!({ + "cat": "message", + "dur": 1000, + "flow_in": true, + "name": "recv unknown".to_string(), + "ph": "X", + "pid": 0, + "tid": 0, + "ts": 0, + }); + assert_eq!(**records.first().unwrap(), expected_record); } #[cfg(target_os = "linux")] @@ -1600,15 +1757,22 @@ edges: .flush(Duration::from_millis(1000)) .await .unwrap(); - let records = simnet_handle().unwrap().records().await; - let expected_record = SimulatorEventRecord { - summary: - "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)" - .to_string(), - start_at: 0, - end_at: 100, - }; - assert!(records.as_ref().unwrap().len() == 1); - assert_eq!(records.unwrap().first().unwrap(), &expected_record); + let records = simnet_handle().unwrap().close().await.unwrap(); + let records = records + .as_array() + .unwrap() + .iter() + .filter(|r| r["ph"] != "M") + .collect::>(); + let expected_record = serde_json::json!({ + "cat": "compute", + "dur": 2000, + "name": "torch.ops.aten.ones.default", + "ph":"X", + "pid":0,"tid":0, + "ts": 0, + }); + assert!(records.len() == 1); + assert_eq!(*records.first().unwrap(), &expected_record); } } diff --git a/hyperactor_multiprocess/src/ping_pong.rs b/hyperactor_multiprocess/src/ping_pong.rs index 9c915aa4..d23f1b4a 100644 --- a/hyperactor_multiprocess/src/ping_pong.rs +++ b/hyperactor_multiprocess/src/ping_pong.rs @@ -91,14 +91,14 @@ edges: dst: local!1 metadata: latency: 1 - - src: local!1 + - src: local!2 dst: local!3 metadata: latency: 2 - src: local!3 - dst: local!1 + dst: local!2 metadata: - latency: 2 + latency: 1 "#; let simnet_config = NetworkConfig::from_yaml(simnet_config_yaml).unwrap(); sim::update_config(simnet_config).await.unwrap(); @@ -108,14 +108,14 @@ edges: // deliver a message to the ping actor with TTL - 2. This will continue until the TTL reaches 0. // The ping actor will then send a message to the done channel to indicate that the game is over. let (done_tx, done_rx) = sys_mailbox.open_once_port(); - let ping_pong_message = PingPongMessage(4, pong_actor_ref.clone(), done_tx.bind()); + let ping_pong_message = PingPongMessage(50, pong_actor_ref.clone(), done_tx.bind()); ping_actor_ref .send(&sys_mailbox, ping_pong_message) .unwrap(); assert!(done_rx.recv().await.unwrap()); - let records = sim::records().await.unwrap(); + let records = simnet::simnet_handle().unwrap().close().await.unwrap(); eprintln!( "records: {}", serde_json::to_string_pretty(&records).unwrap() diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs index 74996005..113b2e6d 100644 --- a/monarch_extension/src/simulator_client.rs +++ b/monarch_extension/src/simulator_client.rs @@ -20,6 +20,7 @@ use hyperactor::simnet::OperationalMessage; use hyperactor::simnet::ProxyMessage; use hyperactor::simnet::SpawnMesh; use hyperactor::simnet::TrainingScriptState; +use hyperactor::simnet::simnet_handle; use monarch_hyperactor::runtime::signal_safe_block_on; use monarch_simulator_lib::bootstrap::bootstrap; use pyo3::exceptions::PyRuntimeError; @@ -125,6 +126,17 @@ impl SimulatorClient { fn set_training_script_state_waiting(&self) -> PyResult<()> { set_training_script_state(TrainingScriptState::Waiting, self.proxy_addr.clone()) } + + fn close(&self, py: Python) -> PyResult { + signal_safe_block_on(py, async move { + simnet_handle() + .unwrap() + .close() + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string())) + .map(|s| s.to_string()) + })? + } } pub(crate) fn register_python_bindings(simulator_client_mod: &Bound<'_, PyModule>) -> PyResult<()> { diff --git a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi index 9cea9d93..fda37a79 100644 --- a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi @@ -52,6 +52,12 @@ class SimulatorClient: """ ... + def close(self) -> str: + """ + Close the client and return perfetto traces of simulator events in JSON format. + """ + ... + def bootstrap_simulator_backend( system_addr: str, proxy_addr: str, world_size: int ) -> None: diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index 00258fda..5db83f43 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -6,7 +6,6 @@ # pyre-strict -import importlib.resources import logging import os import random @@ -14,7 +13,6 @@ import subprocess import tempfile import time -from pathlib import Path from typing import ( Callable, ContextManager as AbstractContextManager, @@ -53,6 +51,7 @@ from monarch.common.shape import NDSlice from monarch.controller.rust_backend.controller import RustController from monarch.rust_backend_mesh import MeshWorld +from monarch.simulator.trace import upload_trace logger: logging.Logger = logging.getLogger(__name__) @@ -113,6 +112,7 @@ def sim_mesh( bootstrap._simulator_client, f"mesh_{i}_worker", ) + dms.append(dm) return dms @@ -141,6 +141,31 @@ def __init__( super().__init__(client, processes, names, mesh_name) self.simulator_client: SimulatorClient = simulator_client + # restore Future.result and Future._set_result to their previous values + def create_exit( + dm: SimMesh[T], + ) -> Callable[[Optional[RemoteException | DeviceException | Exception]], None]: + def exit( + error: Optional[RemoteException | DeviceException | Exception] = None, + ) -> None: + dm.client.shutdown(True, error) + # pyre-ignore + Future.result = OriginalFutureWrapper.result + Future._set_result = OriginalFutureWrapper._set_result + + records = dm.simulator_client.close() + trace_file_path = tempfile.mkdtemp(prefix="simulator_traces") + file_name = f"{trace_file_path}/trace.json" + with open(file_name, "w") as f: + f.write(records) + upload_trace(file_name) + + return exit + + self.exit: Callable[ + [Optional[RemoteException | DeviceException | Exception]], None + ] = create_exit(self) + # monkey patch Future.result and Future._set_result to hook into set_training_script_state_{running,waiting} def activate(self) -> AbstractContextManager[DeviceMesh]: def sim_result(fut: Future[T], timeout: float | None = None) -> T: @@ -157,16 +182,6 @@ def sim_set_result(fut: Future[T], result: T) -> None: return super().activate() - # restore Future.result and Future._set_result to their previous values - def exit( - self, - error: Optional[RemoteException | DeviceException | Exception] = None, - ) -> None: - self.client.shutdown(True, error) - # pyre-ignore - Future.result = OriginalFutureWrapper._result - Future._set_result = OriginalFutureWrapper._set_result - def _random_id(length: int = 14) -> str: """ @@ -201,9 +216,11 @@ def __init__( proxy_addr = proxy_addr or f"unix!@{_random_id()}-proxy" self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}" - self.client_listen_addr: str = f"sim!unix!@client,{proxy_addr}" + + client_proxy_addr = f"unix!@{_random_id()}-proxy" + self.client_listen_addr: str = f"sim!unix!@client,{client_proxy_addr}" self.client_bootstrap_addr: str = ( - f"sim!unix!@client,{proxy_addr},unix!@system,{proxy_addr}" + f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}" ) bootstrap_simulator_backend(self.bootstrap_addr, proxy_addr, world_size)