Skip to content

Fix python logs in asyncio code not showing up in scuba #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions hyperactor_extension/src/telemetry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,42 @@

#![allow(unsafe_op_in_unsafe_fn)]

use std::cell::Cell;

use pyo3::prelude::*;
use tracing::span::EnteredSpan;
// Per-thread slot holding the entered tracing span guard for the Python actor
// method currently running on this thread.
//
// `enter_span` fills the slot when it is empty and `exit_span` resets it to
// `None`, which drops the stored `EnteredSpan` guard.
thread_local! {
static ACTIVE_ACTOR_SPAN: Cell<Option<EnteredSpan>> = const { Cell::new(None) };
}

/// Ensure a `py_actor_method` tracing span is entered on the calling thread.
///
/// If the thread-local slot is empty, a new info-level span is created with the
/// fields `name` (from `method_name`), `target` (from `module_name`) and
/// `actor_id`, entered, and stored in the slot. If a span is already stored, it
/// is kept as-is and the arguments are not used.
///
/// This function always returns `Ok(())`.
#[pyfunction]
pub fn enter_span(module_name: String, method_name: String, actor_id: String) -> PyResult<()> {
    // Move the stored guard out of the cell (leaving `None` behind); only build
    // and enter a fresh span when nothing was stored.
    let active_span = ACTIVE_ACTOR_SPAN.take().unwrap_or_else(|| {
        tracing::info_span!(
            "py_actor_method",
            name = method_name,
            target = module_name,
            actor_id = actor_id
        )
        .entered()
    });
    // Put the guard (pre-existing or new) back so the span stays entered.
    ACTIVE_ACTOR_SPAN.set(Some(active_span));
    Ok(())
}

/// Exit and drop the span stored in the calling thread's thread-local slot.
///
/// The slot is reset to `None`; dropping the previously stored `EnteredSpan`
/// guard is what exits the span. When the slot is already empty this does
/// nothing. This function always returns `Ok(())`.
#[pyfunction]
pub fn exit_span() -> PyResult<()> {
    // `take` leaves `None` in the cell and hands back the old guard, which is
    // dropped right here.
    drop(ACTIVE_ACTOR_SPAN.take());
    Ok(())
}

/// Log a message with the given metadata
/// Log a message with the given metadata
#[pyfunction]
pub fn forward_to_tracing(message: &str, file: &str, lineno: i64, level: i32) {
// Map level number to level name
Expand All @@ -23,15 +56,29 @@ pub fn forward_to_tracing(message: &str, file: &str, lineno: i64, level: i32) {
}
}

use pyo3::Bound;
use pyo3::types::PyModule;

pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
// Register the forward_to_tracing function
let f = wrap_pyfunction!(forward_to_tracing, module)?;
f.setattr(
"__module__",
"monarch._rust_bindings.hyperactor_extension.telemetry",
)?;
module.add_function(f)?;

// Register the span-related functions
let enter_span_fn = wrap_pyfunction!(enter_span, module)?;
enter_span_fn.setattr(
"__module__",
"monarch._rust_bindings.hyperactor_extension.telemetry",
)?;
module.add_function(enter_span_fn)?;

let exit_span_fn = wrap_pyfunction!(exit_span, module)?;
exit_span_fn.setattr(
"__module__",
"monarch._rust_bindings.hyperactor_extension.telemetry",
)?;
module.add_function(exit_span_fn)?;

Ok(())
}
1 change: 1 addition & 0 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use serde::Serialize;
use serde_bytes::ByteBuf;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::span::Id;

use crate::mailbox::PyMailbox;
use crate::proc::InstanceWrapper;
Expand Down
43 changes: 40 additions & 3 deletions python/monarch/_rust_bindings/hyperactor_extension/telemetry.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,49 @@

def forward_to_tracing(message: str, file: str, lineno: int, level: int) -> None:
    """
    Log a message with the given metadata using the tracing system.

    This function forwards Python log messages to the Rust tracing system,
    preserving the original source location and log level.

    Args:
    - message (str): The log message content.
    - file (str): The file where the log message originated.
    - lineno (int): The line number where the log message originated.
    - level (int): The log level:
        - 10: DEBUG
        - 20: INFO
        - 30: WARN
        - 40: ERROR
        - other values default to INFO
    """
    ...

def enter_span(module_name: str, method_name: str, actor_id: str) -> None:
    """
    Enter a tracing span for a Python actor method.

    Creates and enters a new tracing span for the current thread that tracks
    execution of a Python actor method. The span is stored in thread-local
    storage and will be active until exit_span() is called.

    If a span is already active for the current thread, this function will
    preserve that span and not create a new one; the arguments are ignored
    in that case.

    The span itself is always named "py_actor_method"; the arguments are
    attached to it as fields.

    Args:
    - module_name (str): The name of the module containing the actor (recorded as the span's `target` field).
    - method_name (str): The name of the method being called (recorded as the span's `name` field).
    - actor_id (str): The ID of the actor instance (recorded as the span's `actor_id` field).
    """
    ...

def exit_span() -> None:
    """
    Exit the current tracing span for a Python actor method.

    Exits and drops the tracing span that was previously created by enter_span()
    and clears it from the current thread's thread-local storage.
    This should be called when the actor method execution is complete.

    If no span is currently active for this thread, this function has no effect.
    """
    ...
35 changes: 32 additions & 3 deletions python/monarch/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import asyncio
import collections
import contextvars
import functools
import inspect

import itertools
Expand Down Expand Up @@ -38,6 +39,7 @@

import monarch
from monarch import ActorFuture as Future
from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span

from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
Expand All @@ -49,6 +51,7 @@
)
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape

from monarch.common.pickle_flatten import flatten, unflatten
from monarch.common.shape import MeshTrait, NDSlice

Expand Down Expand Up @@ -492,13 +495,29 @@ def handle_cast(
return None
else:
the_method = getattr(self.instance, message.method)._method
result = the_method(self.instance, *args, **kwargs)

if not inspect.iscoroutinefunction(the_method):
enter_span(
the_method.__module__, message.method, str(ctx.mailbox.actor_id)
)
result = the_method(self.instance, *args, **kwargs)
exit_span()
if port is not None:
port.send("result", result)
return None

return self.run_async(ctx, self.run_task(port, result, panic_flag))
async def instrumented():
    """Await the actor method inside a tracing span and return its result.

    Enters the span via ``enter_span`` before the call and exits it via
    ``exit_span`` afterwards, whether the method returns or raises. Any
    exception raised by the method propagates to the caller unchanged.
    """
    enter_span(
        the_method.__module__, message.method, str(ctx.mailbox.actor_id)
    )
    try:
        return await the_method(self.instance, *args, **kwargs)
    finally:
        # Exit the span even when the method raises. Otherwise the span stays
        # stored for this thread, and because enter_span() keeps an already
        # active span, later calls on the thread would reuse the stale one.
        exit_span()

return self.run_async(
ctx,
self.run_task(port, instrumented(), panic_flag),
)
except Exception as e:
traceback.print_exc()
s = ActorError(e)
Expand All @@ -510,7 +529,11 @@ def handle_cast(
else:
raise s from None

async def run_async(self, ctx, coroutine):
async def run_async(
self,
ctx: MonarchContext,
coroutine: Coroutine[Any, None, Any],
) -> None:
_context.set(ctx)
if self.complete_task is None:
self.complete_task = asyncio.create_task(self._complete())
Expand Down Expand Up @@ -564,6 +587,12 @@ def _unpickle(data: bytes, mailbox: Mailbox) -> Any:


class Actor(MeshTrait):
@functools.cached_property
def logger(cls) -> logging.Logger:
    """Return a logger named after this object's class, with level DEBUG.

    The logger is looked up with ``logging.getLogger`` using the class name
    and its level is set to ``logging.DEBUG``. Because this is a
    ``functools.cached_property``, the lookup runs on first access and the
    result is cached on the object afterwards.
    """
    # NOTE(review): despite its name, `cls` receives the instance here
    # (cached_property is resolved on instance attribute access), so
    # `cls.__class__.__name__` is the instance's class name.
    lgr = logging.getLogger(cls.__class__.__name__)
    lgr.setLevel(logging.DEBUG)
    return lgr

@property
def _ndslice(self) -> NDSlice:
raise NotImplementedError(
Expand Down
2 changes: 1 addition & 1 deletion python/monarch/bootstrap_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def emit(self, record: logging.LogRecord) -> None:

# forward logs to rust tracing. Defaults to on.
if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
logging.root.addHandler(TracingForwarder())
logging.root.addHandler(TracingForwarder(level=logging.DEBUG))

try:
with (
Expand Down