Skip to content

psudo paft #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions hyperactor_extension/src/telemetry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,42 @@

#![allow(unsafe_op_in_unsafe_fn)]

use std::cell::Cell;

use pyo3::prelude::*;
use tracing::span::EnteredSpan;
// Per-thread slot holding the entered span guard for the Python actor method
// currently being traced on this thread. It starts out as `None`; `enter_span`
// fills it when it is empty and `exit_span` resets it to `None`, which drops
// the stored guard.
thread_local! {
static ACTIVE_ACTOR_SPAN: Cell<Option<EnteredSpan>> = const { Cell::new(None) };
}

/// Ensure a tracing span is active for a Python actor method on this thread.
///
/// When the thread-local slot is empty, an `info`-level span named
/// `py_actor_method` is created, entered and stored, with the arguments
/// recorded as the span fields `name`, `target` and `actor_id`. When a span is
/// already stored, it is put back untouched and the arguments are ignored.
#[pyfunction]
pub fn enter_span(module_name: String, method_name: String, actor_id: String) -> PyResult<()> {
    // Pull the current guard out of the slot; only build a new span if there was none.
    let active = ACTIVE_ACTOR_SPAN.take().or_else(|| {
        Some(
            tracing::info_span!(
                "py_actor_method",
                name = method_name,
                target = module_name,
                actor_id = actor_id
            )
            .entered(),
        )
    });
    ACTIVE_ACTOR_SPAN.set(active);
    Ok(())
}

/// Clear the thread-local span slot, dropping any span guard stored in it.
///
/// When the slot is already empty this only resets it to `None` again.
#[pyfunction]
pub fn exit_span() -> PyResult<()> {
    // Taking the value leaves `None` behind; dropping the guard exits the span.
    drop(ACTIVE_ACTOR_SPAN.take());
    Ok(())
}

/// Log a message with the given metadata
/// Log a message with the given metadata
#[pyfunction]
pub fn forward_to_tracing(message: &str, file: &str, lineno: i64, level: i32) {
// Map level number to level name
Expand All @@ -23,15 +56,29 @@ pub fn forward_to_tracing(message: &str, file: &str, lineno: i64, level: i32) {
}
}

use pyo3::Bound;
use pyo3::types::PyModule;

pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
// Register the forward_to_tracing function
let f = wrap_pyfunction!(forward_to_tracing, module)?;
f.setattr(
"__module__",
"monarch._rust_bindings.hyperactor_extension.telemetry",
)?;
module.add_function(f)?;

// Register the span-related functions
let enter_span_fn = wrap_pyfunction!(enter_span, module)?;
enter_span_fn.setattr(
"__module__",
"monarch._rust_bindings.hyperactor_extension.telemetry",
)?;
module.add_function(enter_span_fn)?;

let exit_span_fn = wrap_pyfunction!(exit_span, module)?;
exit_span_fn.setattr(
"__module__",
"monarch._rust_bindings.hyperactor_extension.telemetry",
)?;
module.add_function(exit_span_fn)?;

Ok(())
}
1 change: 1 addition & 0 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use serde::Serialize;
use serde_bytes::ByteBuf;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::span::Id;

use crate::mailbox::PyMailbox;
use crate::proc::InstanceWrapper;
Expand Down
43 changes: 40 additions & 3 deletions python/monarch/_rust_bindings/hyperactor_extension/telemetry.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,49 @@

def forward_to_tracing(message: str, file: str, lineno: int, level: int) -> None:
    """
    Log a message with the given metadata using the tracing system.

    This function forwards Python log messages to the Rust tracing system,
    preserving the original source location and log level.

    Args:
    - message (str): The log message content.
    - file (str): The file where the log message originated.
    - lineno (int): The line number where the log message originated.
    - level (int): The log level:
        - 10: DEBUG
        - 20: INFO
        - 30: WARN
        - 40: ERROR
        - other values default to INFO
    """
    ...

def enter_span(module_name: str, method_name: str, actor_id: str) -> None:
    """
    Enter a tracing span for a Python actor method.

    Creates and enters a tracing span named "py_actor_method" on the current
    thread and stores it in thread-local storage, where it stays active until
    exit_span() is called.

    If a span is already stored for the current thread, that span is kept and
    no new one is created; the arguments of this call are then ignored.

    Args:
    - module_name (str): The name of the module containing the actor method
      (recorded as the span's `target` field).
    - method_name (str): The name of the method being called (recorded as the
      span's `name` field).
    - actor_id (str): The ID of the actor instance (recorded as the span's
      `actor_id` field).
    """
    ...

def exit_span() -> None:
    """
    Exit the current tracing span for a Python actor method.

    Clears the thread-local slot filled by enter_span(), which exits and drops
    the stored span. Call this once the actor method has finished executing.

    If no span is currently stored for this thread, this function has no effect.
    """
    ...
35 changes: 32 additions & 3 deletions python/monarch/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import asyncio
import collections
import contextvars
import functools
import inspect

import itertools
Expand Down Expand Up @@ -38,6 +39,7 @@

import monarch
from monarch import ActorFuture as Future
from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span

from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
Expand All @@ -49,6 +51,7 @@
)
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape

from monarch.common.pickle_flatten import flatten, unflatten
from monarch.common.shape import MeshTrait, NDSlice

Expand Down Expand Up @@ -492,13 +495,29 @@ def handle_cast(
return None
else:
the_method = getattr(self.instance, message.method)._method
result = the_method(self.instance, *args, **kwargs)

if not inspect.iscoroutinefunction(the_method):
enter_span(
the_method.__module__, message.method, str(ctx.mailbox.actor_id)
)
result = the_method(self.instance, *args, **kwargs)
exit_span()
if port is not None:
port.send("result", result)
return None

return self.run_async(ctx, self.run_task(port, result, panic_flag))
async def instrumented():
    """Await the actor method inside a tracing span and return its result."""
    enter_span(
        the_method.__module__, message.method, str(ctx.mailbox.actor_id)
    )
    try:
        return await the_method(self.instance, *args, **kwargs)
    finally:
        # Always release the span, even when the method raises: enter_span()
        # keeps an already-stored span, so a leaked one would be reused by
        # every later call on this thread.
        exit_span()

return self.run_async(
ctx,
self.run_task(port, instrumented(), panic_flag),
)
except Exception as e:
traceback.print_exc()
s = ActorError(e)
Expand All @@ -510,7 +529,11 @@ def handle_cast(
else:
raise s from None

async def run_async(self, ctx, coroutine):
async def run_async(
self,
ctx: MonarchContext,
coroutine: Coroutine[Any, None, Any],
) -> None:
_context.set(ctx)
if self.complete_task is None:
self.complete_task = asyncio.create_task(self._complete())
Expand Down Expand Up @@ -564,6 +587,12 @@ def _unpickle(data: bytes, mailbox: Mailbox) -> Any:


class Actor(MeshTrait):
@functools.cached_property
def logger(self) -> logging.Logger:
    """
    Logger named after this actor's concrete class.

    The logger is looked up by class name, set to DEBUG level, and cached on
    the instance after the first access.
    """
    # This is an instance-level property, so the receiver is `self`, not a class.
    lgr = logging.getLogger(self.__class__.__name__)
    lgr.setLevel(logging.DEBUG)
    return lgr

@property
def _ndslice(self) -> NDSlice:
raise NotImplementedError(
Expand Down
23 changes: 2 additions & 21 deletions python/monarch/bootstrap_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,16 @@ def invoke_main():
# behavior of std out as if it were a terminal.
sys.stdout.reconfigure(line_buffering=True)
global bootstrap_main
from monarch._rust_bindings.hyperactor_extension.telemetry import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
forward_to_tracing,
)

# TODO: figure out what from worker_main.py we should reproduce here.

class TracingForwarder(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
try:
forward_to_tracing(
record.getMessage(),
record.filename or "",
record.lineno or 0,
record.levelno,
)
except AttributeError:
forward_to_tracing(
record.__str__(),
record.filename or "",
record.lineno or 0,
record.levelno,
)
from monarch.telemetry import TracingForwarder

if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
raise RuntimeError("Error during bootstrap for testing")

# forward logs to rust tracing. Defaults to on.
if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
logging.root.addHandler(TracingForwarder())
logging.root.addHandler(TracingForwarder(level=logging.DEBUG))

try:
with (
Expand Down
32 changes: 32 additions & 0 deletions python/monarch/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import logging

from monarch._rust_bindings.hyperactor_extension.telemetry import ( # @manual=//monarch/monarch_extension:monarch_extension
forward_to_tracing,
)


class TracingForwarder(logging.Handler):
    """Logging handler that passes each record on to `forward_to_tracing`."""

    def emit(self, record: logging.LogRecord) -> None:
        """
        Forward one log record, with its file name, line number and level.

        The formatted message is used; if an AttributeError is raised while
        forwarding it, the record's string form is forwarded instead.
        """

        def _forward(text: str) -> None:
            # Missing file name / line number are replaced by "" and 0.
            forward_to_tracing(
                text,
                record.filename or "",
                record.lineno or 0,
                record.levelno,
            )

        try:
            _forward(record.getMessage())
        except AttributeError:
            _forward(record.__str__())