diff --git a/hyperactor_extension/src/telemetry.rs b/hyperactor_extension/src/telemetry.rs
index f27f09f6..dc695b8a 100644
--- a/hyperactor_extension/src/telemetry.rs
+++ b/hyperactor_extension/src/telemetry.rs
@@ -8,9 +8,42 @@
 
 #![allow(unsafe_op_in_unsafe_fn)]
 
+use std::cell::Cell;
+
 use pyo3::prelude::*;
+use tracing::span::EnteredSpan;
+// Thread local to store the span currently entered by `enter_span` on this thread
+thread_local! {
+    static ACTIVE_ACTOR_SPAN: Cell<Option<EnteredSpan>> = const { Cell::new(None) };
+}
+
+/// Enter a `py_actor_method` span on this thread unless one is already active
+#[pyfunction]
+pub fn enter_span(module_name: String, method_name: String, actor_id: String) -> PyResult<()> {
+    let mut maybe_span = ACTIVE_ACTOR_SPAN.take();
+    if maybe_span.is_none() {
+        maybe_span = Some(
+            tracing::info_span!(
+                "py_actor_method",
+                name = method_name,
+                target = module_name,
+                actor_id = actor_id
+            )
+            .entered(),
+        );
+    }
+    ACTIVE_ACTOR_SPAN.set(maybe_span);
+    Ok(())
+}
+
+/// Exit the span stored in the thread local by dropping it; no-op when none is active
+#[pyfunction]
+pub fn exit_span() -> PyResult<()> {
+    ACTIVE_ACTOR_SPAN.replace(None);
+    Ok(())
+}
 
 /// Log a message with the given metadata
 #[pyfunction]
 pub fn forward_to_tracing(message: &str, file: &str, lineno: i64, level: i32) {
     // Map level number to level name
@@ -23,15 +56,29 @@ pub fn forward_to_tracing(message: &str, file: &str, lineno: i64, level: i32) {
     }
 }
 
-use pyo3::Bound;
-use pyo3::types::PyModule;
-
 pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
+    // Register the forward_to_tracing function
     let f = wrap_pyfunction!(forward_to_tracing, module)?;
     f.setattr(
         "__module__",
         "monarch._rust_bindings.hyperactor_extension.telemetry",
     )?;
     module.add_function(f)?;
+
+    // Register the span-related functions
+    let enter_span_fn = wrap_pyfunction!(enter_span, module)?;
+    enter_span_fn.setattr(
+        "__module__",
+        "monarch._rust_bindings.hyperactor_extension.telemetry",
+    )?;
+    module.add_function(enter_span_fn)?;
+
+    let exit_span_fn = wrap_pyfunction!(exit_span, module)?;
+    exit_span_fn.setattr(
+        "__module__",
+        "monarch._rust_bindings.hyperactor_extension.telemetry",
+    )?;
+    module.add_function(exit_span_fn)?;
+
     Ok(())
 }
diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs
index 7ac2137b..f26925be 100644
--- a/monarch_hyperactor/src/actor.rs
+++ b/monarch_hyperactor/src/actor.rs
@@ -37,6 +37,7 @@ use serde::Serialize;
 use serde_bytes::ByteBuf;
 use tokio::sync::Mutex;
 use tokio::sync::oneshot;
+use tracing::span::Id;
 
 use crate::mailbox::PyMailbox;
 use crate::proc::InstanceWrapper;
diff --git a/python/monarch/_rust_bindings/hyperactor_extension/telemetry.pyi b/python/monarch/_rust_bindings/hyperactor_extension/telemetry.pyi
index f7201df1..f04551ea 100644
--- a/python/monarch/_rust_bindings/hyperactor_extension/telemetry.pyi
+++ b/python/monarch/_rust_bindings/hyperactor_extension/telemetry.pyi
@@ -6,12 +6,49 @@
 
 def forward_to_tracing(message: str, file: str, lineno: int, level: int) -> None:
     """
-    Log a message with the given metadata.
+    Log a message with the given metadata using the tracing system.
+
+    This function forwards Python log messages to the Rust tracing system,
+    preserving the original source location and log level.
 
     Args:
-    - message (str): The log message.
+    - message (str): The log message content.
     - file (str): The file where the log message originated.
     - lineno (int): The line number where the log message originated.
-    - level (int): The log level (10 for debug, 20 for info, 30 for warn, 40 for error).
+    - level (int): The log level:
+        - 10: DEBUG
+        - 20: INFO
+        - 30: WARN
+        - 40: ERROR
+        - other values default to INFO
+    """
+    ...
+
+def enter_span(module_name: str, method_name: str, actor_id: str) -> None:
+    """
+    Enter a tracing span for a Python actor method.
+
+    Creates and enters a new tracing span for the current thread that tracks
+    execution of a Python actor method. The span is stored in thread-local
+    storage and will be active until exit_span() is called.
+
+    If a span is already active for the current thread, this function will
+    preserve that span and not create a new one.
+
+    Args:
+    - module_name (str): The name of the module containing the actor (recorded as the span's `target` field).
+    - method_name (str): The name of the method being called (recorded as the span's `name` field).
+    - actor_id (str): The ID of the actor instance (included as a field in the span).
+    """
+    ...
+
+def exit_span() -> None:
+    """
+    Exit the current tracing span for a Python actor method.
+
+    Exits and drops the tracing span that was previously created by enter_span().
+    This should be called when the actor method execution is complete.
+
+    If no span is currently active for this thread, this function has no effect.
     """
     ...
diff --git a/python/monarch/actor_mesh.py b/python/monarch/actor_mesh.py
index cb92d0f6..d94452db 100644
--- a/python/monarch/actor_mesh.py
+++ b/python/monarch/actor_mesh.py
@@ -7,6 +7,7 @@
 import asyncio
 import collections
 import contextvars
+import functools
 import inspect
 import itertools
 
@@ -38,6 +39,7 @@
 
 import monarch
 from monarch import ActorFuture as Future
+from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span
 
 from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage
 from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
@@ -49,6 +51,7 @@
 )
 from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
+
 from monarch.common.pickle_flatten import flatten, unflatten
 from monarch.common.shape import MeshTrait, NDSlice
 
@@ -492,13 +495,33 @@ def handle_cast(
                 return None
             else:
                 the_method = getattr(self.instance, message.method)._method
-                result = the_method(self.instance, *args, **kwargs)
+
                 if not inspect.iscoroutinefunction(the_method):
+                    enter_span(
+                        the_method.__module__, message.method, str(ctx.mailbox.actor_id)
+                    )
+                    try:
+                        result = the_method(self.instance, *args, **kwargs)
+                    finally:
+                        # Always exit the span, even when the method raises.
+                        exit_span()
                     if port is not None:
                         port.send("result", result)
                     return None
 
-                return self.run_async(ctx, self.run_task(port, result, panic_flag))
+                async def instrumented():
+                    enter_span(
+                        the_method.__module__, message.method, str(ctx.mailbox.actor_id)
+                    )
+                    try:
+                        return await the_method(self.instance, *args, **kwargs)
+                    finally:
+                        exit_span()
+
+                return self.run_async(
+                    ctx,
+                    self.run_task(port, instrumented(), panic_flag),
+                )
         except Exception as e:
             traceback.print_exc()
             s = ActorError(e)
@@ -510,7 +533,11 @@ def handle_cast(
             else:
                 raise s from None
 
-    async def run_async(self, ctx, coroutine):
+    async def run_async(
+        self,
+        ctx: MonarchContext,
+        coroutine: Coroutine[Any, None, Any],
+    ) -> None:
         _context.set(ctx)
         if self.complete_task is None:
             self.complete_task = asyncio.create_task(self._complete())
@@ -564,6 +591,12 @@ def _unpickle(data: bytes, mailbox: Mailbox) -> Any:
 
 
 class Actor(MeshTrait):
+    @functools.cached_property
+    def logger(self) -> logging.Logger:
+        lgr = logging.getLogger(self.__class__.__name__)
+        lgr.setLevel(logging.DEBUG)
+        return lgr
+
     @property
     def _ndslice(self) -> NDSlice:
         raise NotImplementedError(
diff --git a/python/monarch/bootstrap_main.py b/python/monarch/bootstrap_main.py
index 594ef895..77befe9b 100644
--- a/python/monarch/bootstrap_main.py
+++ b/python/monarch/bootstrap_main.py
@@ -30,35 +30,16 @@ def invoke_main():
     # behavior of std out as if it were a terminal.
     sys.stdout.reconfigure(line_buffering=True)
     global bootstrap_main
-    from monarch._rust_bindings.hyperactor_extension.telemetry import (  # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
-        forward_to_tracing,
-    )
 
     # TODO: figure out what from worker_main.py we should reproduce here.
-
-    class TracingForwarder(logging.Handler):
-        def emit(self, record: logging.LogRecord) -> None:
-            try:
-                forward_to_tracing(
-                    record.getMessage(),
-                    record.filename or "",
-                    record.lineno or 0,
-                    record.levelno,
-                )
-            except AttributeError:
-                forward_to_tracing(
-                    record.__str__(),
-                    record.filename or "",
-                    record.lineno or 0,
-                    record.levelno,
-                )
+    from monarch.telemetry import TracingForwarder
 
     if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
         raise RuntimeError("Error during bootstrap for testing")
 
     # forward logs to rust tracing. Defaults to on.
     if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
-        logging.root.addHandler(TracingForwarder())
+        logging.root.addHandler(TracingForwarder(level=logging.DEBUG))
 
     try:
         with (
diff --git a/python/monarch/telemetry.py b/python/monarch/telemetry.py
new file mode 100644
index 00000000..84865498
--- /dev/null
+++ b/python/monarch/telemetry.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+import logging
+
+from monarch._rust_bindings.hyperactor_extension.telemetry import (  # @manual=//monarch/monarch_extension:monarch_extension
+    forward_to_tracing,
+)
+
+
+class TracingForwarder(logging.Handler):
+    """Logging handler that forwards Python log records to the Rust tracing system."""
+
+    def emit(self, record: logging.LogRecord) -> None:
+        try:
+            forward_to_tracing(
+                record.getMessage(),
+                record.filename or "",
+                record.lineno or 0,
+                record.levelno,
+            )
+        except AttributeError:
+            # On AttributeError, retry with the record's string form as the message.
+            forward_to_tracing(
+                record.__str__(),
+                record.filename or "",
+                record.lineno or 0,
+                record.levelno,
+            )