diff --git a/mlua-sys/src/luau/luarequire.rs b/mlua-sys/src/luau/luarequire.rs index 7225cb9b..1a1974d6 100644 --- a/mlua-sys/src/luau/luarequire.rs +++ b/mlua-sys/src/luau/luarequire.rs @@ -30,7 +30,7 @@ pub struct luarequire_Configuration { unsafe extern "C" fn(L: *mut lua_State, ctx: *mut c_void, requirer_chunkname: *const c_char) -> bool, // Resets the internal state to point at the requirer module. - pub reset: unsafe extern "C" fn( + pub reset: unsafe extern "C-unwind" fn( L: *mut lua_State, ctx: *mut c_void, requirer_chunkname: *const c_char, @@ -39,15 +39,15 @@ pub struct luarequire_Configuration { // Resets the internal state to point at an aliased module, given its exact path from a configuration // file. This function is only called when an alias's path cannot be resolved relative to its // configuration file. - pub jump_to_alias: unsafe extern "C" fn( + pub jump_to_alias: unsafe extern "C-unwind" fn( L: *mut lua_State, ctx: *mut c_void, path: *const c_char, ) -> luarequire_NavigateResult, // Navigates through the context by making mutations to the internal state. - pub to_parent: unsafe extern "C" fn(L: *mut lua_State, ctx: *mut c_void) -> luarequire_NavigateResult, - pub to_child: unsafe extern "C" fn( + pub to_parent: unsafe extern "C-unwind" fn(L: *mut lua_State, ctx: *mut c_void) -> luarequire_NavigateResult, + pub to_child: unsafe extern "C-unwind" fn( L: *mut lua_State, ctx: *mut c_void, name: *const c_char, diff --git a/src/luau/require.rs b/src/luau/require.rs index 152114de..92e07666 100644 --- a/src/luau/require.rs +++ b/src/luau/require.rs @@ -6,17 +6,19 @@ use std::os::raw::{c_char, c_int, c_void}; use std::path::{Component, Path, PathBuf}; use std::result::Result as StdResult; use std::{env, fmt, fs, mem, ptr}; - -use crate::error::Result; +use crate::error::{Result, Error}; use crate::function::Function; use crate::state::{callback_error_ext, Lua}; use crate::table::Table; use crate::types::MaybeSend; +use crate::traits::IntoLua; +use crate::state::RawLua; /// An error that can occur during navigation in the Luau `require` system. pub enum NavigateError { Ambiguous, NotFound, + Error(Error) } #[cfg(feature = "luau")] @@ -31,6 +33,7 @@ impl IntoNavigateResult for StdResult<(), NavigateError> { Ok(()) => ffi::luarequire_NavigateResult::Success, Err(NavigateError::Ambiguous) => ffi::luarequire_NavigateResult::Ambiguous, Err(NavigateError::NotFound) => ffi::luarequire_NavigateResult::NotFound, + Err(NavigateError::Error(_)) => unreachable!() } } } @@ -320,42 +323,70 @@ pub(super) unsafe extern "C" fn init_config(config: *mut ffi::luarequire_Configu this.is_require_allowed(&chunk_name) } - unsafe extern "C" fn reset( - _state: *mut ffi::lua_State, + unsafe extern "C-unwind" fn reset( + state: *mut ffi::lua_State, ctx: *mut c_void, requirer_chunkname: *const c_char, ) -> ffi::luarequire_NavigateResult { let this = &*(ctx as *const Box); let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy(); - this.reset(&chunk_name).into_nav_result() + match this.reset(&chunk_name) { + Err(NavigateError::Error(err)) => { + let raw_lua = RawLua::init_from_ptr(state, false); + err.push_into_stack(&raw_lua.lock()).expect("mlua internal: failed to push error to stack"); + ffi::lua_error(state); + }, + error => error.into_nav_result() + } } - unsafe extern "C" fn jump_to_alias( - _state: *mut ffi::lua_State, + unsafe extern "C-unwind" fn jump_to_alias( + state: *mut ffi::lua_State, ctx: *mut c_void, path: *const c_char, ) -> ffi::luarequire_NavigateResult { let this = &*(ctx as *const Box); let path = CStr::from_ptr(path).to_string_lossy(); - this.jump_to_alias(&path).into_nav_result() + match this.jump_to_alias(&path) { + Err(NavigateError::Error(err)) => { + let raw_lua = RawLua::init_from_ptr(state, false); + err.push_into_stack(&raw_lua.lock()).expect("mlua internal: failed to push error to stack"); + ffi::lua_error(state); + }, + error => error.into_nav_result() + } } - unsafe extern "C" fn to_parent( - _state: *mut ffi::lua_State, + unsafe extern "C-unwind" fn to_parent( + state: *mut ffi::lua_State, ctx: *mut c_void, ) -> ffi::luarequire_NavigateResult { let this = &*(ctx as *const Box); - this.to_parent().into_nav_result() + match this.to_parent() { + Err(NavigateError::Error(err)) => { + let raw_lua = RawLua::init_from_ptr(state, false); + err.push_into_stack(&raw_lua.lock()).expect("mlua internal: failed to push error to stack"); + ffi::lua_error(state); + }, + error => error.into_nav_result() + } } - unsafe extern "C" fn to_child( - _state: *mut ffi::lua_State, + unsafe extern "C-unwind" fn to_child( + state: *mut ffi::lua_State, ctx: *mut c_void, name: *const c_char, ) -> ffi::luarequire_NavigateResult { let this = &*(ctx as *const Box); let name = CStr::from_ptr(name).to_string_lossy(); - this.to_child(&name).into_nav_result() + match this.to_child(&name) { + Err(NavigateError::Error(err)) => { + let raw_lua = RawLua::init_from_ptr(state, false); + err.push_into_stack(&raw_lua.lock()).expect("mlua internal: failed to push error to stack"); + ffi::lua_error(state); + }, + error => error.into_nav_result() + } } unsafe extern "C" fn is_module_present(_state: *mut ffi::lua_State, ctx: *mut c_void) -> bool { diff --git a/src/state/raw.rs b/src/state/raw.rs index cc4ce1cf..3669ce85 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -179,7 +179,7 @@ impl RawLua { rawlua } - pub(super) unsafe fn init_from_ptr(state: *mut ffi::lua_State, owned: bool) -> XRc> { + pub(crate) unsafe fn init_from_ptr(state: *mut ffi::lua_State, owned: bool) -> XRc> { assert!(!state.is_null(), "Lua state is NULL"); if let Some(lua) = Self::try_from_ptr(state) { return lua; diff --git a/tests/luau/require.rs b/tests/luau/require.rs index a38b0cd6..30d77240 100644 --- a/tests/luau/require.rs +++ b/tests/luau/require.rs @@ -1,4 +1,10 @@ -use mlua::{IntoLua, Lua, Result, Value}; +use mlua::{IntoLua, Lua, Result, Value, NavigateError, Error, prelude::LuaRequire}; +use std::io::Result as IoResult; +use std::result::Result as StdResult; +use std::{env, fs}; +use std::path::{Component, Path, PathBuf}; +use std::cell::RefCell; +use std::collections::VecDeque; fn run_require(lua: &Lua, path: impl IntoLua) -> Result { lua.load(r#"return require(...)"#).call(path) @@ -141,3 +147,245 @@ async fn test_async_require() -> Result<()> { .exec_async() .await } + +#[test] +fn test_require_custom_error() { + let lua = Lua::new(); + lua.globals().set("require", lua.create_require_function(TextRequirer::new(true)).unwrap()).unwrap(); + + let memusage = lua.used_memory(); + + let res = run_require(&lua, "@failed/failure"); + assert!(res.is_err()); + println!("{}", res.clone().unwrap_err().to_string()); + assert!((res.unwrap_err().to_string()).contains("custom error")); + + // ensure repeat calls do not lead to error + let res = run_require(&lua, "@failed/failure"); + assert!(res.is_err()); + println!("{}", res.clone().unwrap_err().to_string()); + assert!((res.unwrap_err().to_string()).contains("custom error")); + + // Ensure valid stack after end of tests + let stack_count: i32 = unsafe { + lua.exec_raw((), |state| { + let n = mlua::ffi::lua_gettop(state); + mlua::ffi::lua_pushinteger(state, n.into()); + }).unwrap() + }; + + assert_eq!(stack_count, 0); + + lua.gc_collect().unwrap(); + lua.gc_collect().unwrap(); + + assert_eq!(memusage, lua.used_memory()); +} + +/// Simple test require trait to test custom errors +#[derive(Default)] +struct TextRequirer { + abs_path: RefCell, + rel_path: RefCell, + module_path: RefCell, + error_on_reset: bool, +} + +impl TextRequirer { + pub fn new(error_on_reset: bool) -> Self { + Self { + error_on_reset, + ..Default::default() + } + } + + fn normalize_chunk_name(chunk_name: &str) -> &str { + if let Some((path, line)) = chunk_name.split_once(':') { + if line.parse::().is_ok() { + return path; + } + } + chunk_name + } + + // Normalizes the path by removing unnecessary components + fn normalize_path(path: &Path) -> PathBuf { + let mut components = VecDeque::new(); + + for comp in path.components() { + match comp { + Component::Prefix(..) | Component::RootDir => { + components.push_back(comp); + } + Component::CurDir => {} + Component::ParentDir => { + if matches!(components.back(), None | Some(Component::ParentDir)) { + components.push_back(Component::ParentDir); + } else if matches!(components.back(), Some(Component::Normal(..))) { + components.pop_back(); + } + } + Component::Normal(..) => components.push_back(comp), + } + } + + if matches!(components.front(), None | Some(Component::Normal(..))) { + components.push_front(Component::CurDir); + } + + // Join the components back together + components.into_iter().collect() + } + + fn find_module_path(path: &Path) -> StdResult { + let mut found_path = None; + + let current_ext = (path.extension().and_then(|s| s.to_str())) + .map(|s| format!("{s}.")) + .unwrap_or_default(); + for ext in ["luau", "lua"] { + let candidate = path.with_extension(format!("{current_ext}{ext}")); + if candidate.is_file() { + if found_path.is_some() { + return Err(NavigateError::Ambiguous); + } + found_path = Some(candidate); + } + } + if path.is_dir() { + if found_path.is_some() { + return Err(NavigateError::Ambiguous); + } + + for component in ["init.luau", "init.lua"] { + let candidate = path.join(component); + if candidate.is_file() { + if found_path.is_some() { + return Err(NavigateError::Ambiguous); + } + found_path = Some(candidate); + } + } + + if found_path.is_none() { + found_path = Some(PathBuf::new()); + } + } + + found_path.ok_or(NavigateError::NotFound) + } +} + +impl LuaRequire for TextRequirer { + fn is_require_allowed(&self, chunk_name: &str) -> bool { + chunk_name.starts_with('@') + } + + fn reset(&self, chunk_name: &str) -> StdResult<(), NavigateError> { + if self.error_on_reset { + return Err(NavigateError::Error(Error::runtime("custom error".to_string()))); + } + + if !chunk_name.starts_with('@') { + return Err(NavigateError::NotFound); + } + let chunk_name = &Self::normalize_chunk_name(chunk_name)[1..]; + let path = Self::normalize_path(chunk_name.as_ref()); + + if path.extension() == Some("rs".as_ref()) { + let cwd = match env::current_dir() { + Ok(cwd) => cwd, + Err(_) => return Err(NavigateError::NotFound), + }; + self.abs_path.replace(Self::normalize_path(&cwd.join(&path))); + self.rel_path.replace(path); + self.module_path.replace(PathBuf::new()); + + return Ok(()); + } + + if path.is_absolute() { + let module_path = Self::find_module_path(&path)?; + self.abs_path.replace(path.clone()); + self.rel_path.replace(path); + self.module_path.replace(module_path); + } else { + // Relative path + let cwd = match env::current_dir() { + Ok(cwd) => cwd, + Err(_) => return Err(NavigateError::NotFound), + }; + let abs_path = cwd.join(&path); + let module_path = Self::find_module_path(&abs_path)?; + self.abs_path.replace(Self::normalize_path(&abs_path)); + self.rel_path.replace(path); + self.module_path.replace(module_path); + } + + Ok(()) + } + + fn jump_to_alias(&self, path: &str) -> StdResult<(), NavigateError> { + let path = Self::normalize_path(path.as_ref()); + let module_path = Self::find_module_path(&path)?; + + self.abs_path.replace(path.clone()); + self.rel_path.replace(path); + self.module_path.replace(module_path); + + Ok(()) + } + + fn to_parent(&self) -> StdResult<(), NavigateError> { + let mut abs_path = self.abs_path.borrow().clone(); + if !abs_path.pop() { + return Err(NavigateError::NotFound); + } + let mut rel_parent = self.rel_path.borrow().clone(); + rel_parent.pop(); + let module_path = Self::find_module_path(&abs_path)?; + + self.abs_path.replace(abs_path); + self.rel_path.replace(Self::normalize_path(&rel_parent)); + self.module_path.replace(module_path); + + Ok(()) + } + + fn to_child(&self, name: &str) -> StdResult<(), NavigateError> { + let abs_path = self.abs_path.borrow().join(name); + let rel_path = self.rel_path.borrow().join(name); + let module_path = Self::find_module_path(&abs_path)?; + + self.abs_path.replace(abs_path); + self.rel_path.replace(rel_path); + self.module_path.replace(module_path); + + Ok(()) + } + + fn is_module_present(&self) -> bool { + self.module_path.borrow().is_file() + } + + fn contents(&self) -> IoResult> { + fs::read(&*self.module_path.borrow()) + } + + fn chunk_name(&self) -> String { + format!("@{}", self.rel_path.borrow().display()) + } + + fn cache_key(&self) -> Vec { + self.module_path.borrow().display().to_string().into_bytes() + } + + fn is_config_present(&self) -> bool { + self.abs_path.borrow().join(".luaurc").is_file() + } + + fn config(&self) -> IoResult> { + fs::read(self.abs_path.borrow().join(".luaurc")) + } +} +