From f10815e903ba29888ac972e5b7e10c40bd938937 Mon Sep 17 00:00:00 2001 From: IQuant Date: Mon, 25 Nov 2024 19:35:38 +0300 Subject: [PATCH] A better way to return things --- ewext/noita_api/src/lib.rs | 9 ++- ewext/noita_api/src/lua.rs | 96 +++++++++++++++++++++++++++++++- ewext/noita_api_macro/src/lib.rs | 39 +------------ ewext/src/lib.rs | 17 +++--- 4 files changed, 112 insertions(+), 49 deletions(-) diff --git a/ewext/noita_api/src/lib.rs b/ewext/noita_api/src/lib.rs index 20b8c2b1..7b347f29 100644 --- a/ewext/noita_api/src/lib.rs +++ b/ewext/noita_api/src/lib.rs @@ -1,10 +1,12 @@ +use std::num::NonZero; + pub mod lua; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct EntityID(pub isize); +pub struct EntityID(pub NonZero); #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct ComponentID(pub isize); +pub struct ComponentID(pub NonZero); pub struct Obj(pub usize); @@ -14,6 +16,7 @@ noita_api_macro::generate_components!(); pub mod raw { use super::{Color, ComponentID, EntityID, Obj}; + use crate::lua::LuaGetValue; use crate::lua::LuaPutValue; use std::borrow::Cow; @@ -28,7 +31,7 @@ pub mod raw { ) -> eyre::Result<()> { let lua = LuaState::current()?; lua.get_global(c"ComponentGetValue2"); - lua.push_integer(component.0); + lua.push_integer(component.0.into()); lua.push_string(field); lua.call(2, expected_results); Ok(()) diff --git a/ewext/noita_api/src/lua.rs b/ewext/noita_api/src/lua.rs index ef544f37..c9e1bd82 100644 --- a/ewext/noita_api/src/lua.rs +++ b/ewext/noita_api/src/lua.rs @@ -207,13 +207,13 @@ impl LuaPutValue for str { impl LuaPutValue for EntityID { fn put(&self, lua: LuaState) { - self.0.put(lua); + isize::from(self.0).put(lua); } } impl LuaPutValue for ComponentID { fn put(&self, lua: LuaState) { - self.0.put(lua); + isize::from(self.0).put(lua); } } @@ -244,3 +244,95 @@ impl LuaPutValue for Option { } } } + +/// Trait for arguments that can be retrieved from the lua stack. +pub(crate) trait LuaGetValue { + fn get(lua: LuaState, index: i32) -> eyre::Result + where + Self: Sized; + fn size() -> i32 { + 1 + } +} + +impl LuaGetValue for i32 { + fn get(lua: LuaState, index: i32) -> eyre::Result { + Ok(lua.to_integer(index) as Self) + } +} + +impl LuaGetValue for isize { + fn get(lua: LuaState, index: i32) -> eyre::Result { + Ok(lua.to_integer(index) as Self) + } +} + +impl LuaGetValue for u32 { + fn get(lua: LuaState, index: i32) -> eyre::Result { + Ok(unsafe { mem::transmute(lua.to_integer(index) as i32) }) + } +} + +impl LuaGetValue for f32 { + fn get(lua: LuaState, index: i32) -> eyre::Result { + Ok(lua.to_number(index) as f32) + } +} + +impl LuaGetValue for f64 { + fn get(lua: LuaState, index: i32) -> eyre::Result { + Ok(lua.to_number(index)) + } +} + +impl LuaGetValue for bool { + fn get(lua: LuaState, index: i32) -> eyre::Result { + Ok(lua.to_bool(index)) + } +} + +impl LuaGetValue for Option { + fn get(lua: LuaState, index: i32) -> eyre::Result { + let ent = lua.to_integer(index); + Ok(if ent == 0 { + None + } else { + Some(EntityID(ent.try_into().unwrap())) + }) + } +} + +impl LuaGetValue for Option { + fn get(lua: LuaState, index: i32) -> eyre::Result { + let com = lua.to_integer(index); + Ok(if com == 0 { + None + } else { + Some(ComponentID(com.try_into().unwrap())) + }) + } +} + +impl LuaGetValue for Cow<'static, str> { + fn get(lua: LuaState, index: i32) -> eyre::Result { + Ok(lua.to_string(index)?.into()) + } +} + +impl LuaGetValue for () { + fn get(_lua: LuaState, _index: i32) -> eyre::Result { + Ok(()) + } +} + +impl LuaGetValue for Obj { + fn get(_lua: LuaState, _index: i32) -> eyre::Result { + todo!() + } +} + +impl LuaGetValue for Color { + fn get(_lua: LuaState, _index: i32) -> eyre::Result { + todo!() + } +} diff --git a/ewext/noita_api_macro/src/lib.rs b/ewext/noita_api_macro/src/lib.rs index 0dab6765..5a65c503 100644 --- a/ewext/noita_api_macro/src/lib.rs +++ b/ewext/noita_api_macro/src/lib.rs @@ -2,7 +2,6 @@ use std::ffi::CString; use heck::ToSnekCase; use proc_macro::TokenStream; -use proc_macro2::Ident; use quote::{format_ident, quote}; use serde::Deserialize; @@ -89,35 +88,11 @@ impl Typ2 { fn as_rust_type_return(&self) -> proc_macro2::TokenStream { match self { Typ2::String => quote! {Cow<'static, str>}, + Typ2::EntityID => quote! {Option}, + Typ2::ComponentID => quote!(Option), _ => self.as_rust_type(), } } - - fn generate_lua_push(&self, arg_name: Ident) -> proc_macro2::TokenStream { - match self { - Typ2::Int => quote! {lua.push_integer(#arg_name as isize)}, - Typ2::Number => quote! {lua.push_number(#arg_name)}, - Typ2::String => quote! {lua.push_string(&#arg_name)}, - Typ2::Bool => quote! {lua.push_bool(#arg_name)}, - Typ2::EntityID => quote! {lua.push_integer(#arg_name.0 as isize)}, - Typ2::ComponentID => quote! {lua.push_integer(#arg_name.0 as isize)}, - Typ2::Obj => quote! { todo!() }, - Typ2::Color => quote! { todo!() }, - } - } - - fn generate_lua_get(&self, index: i32) -> proc_macro2::TokenStream { - match self { - Typ2::Int => quote! {lua.to_integer(#index) as i32}, - Typ2::Number => quote! {lua.to_number(#index)}, - Typ2::String => quote! { lua.to_string(#index)?.into() }, - Typ2::Bool => quote! {lua.to_bool(#index)}, - Typ2::EntityID => quote! {EntityID(lua.to_integer(#index))}, - Typ2::ComponentID => quote! {ComponentID(lua.to_integer(#index))}, - Typ2::Obj => quote! { todo!() }, - Typ2::Color => quote! { todo!() }, - } - } } #[derive(Deserialize)] @@ -260,14 +235,6 @@ fn generate_code_for_api_fn(api_fn: ApiFn) -> proc_macro2::TokenStream { // } }; - let ret_expr = if api_fn.rets.is_empty() { - quote! { () } - } else { - // TODO support for more than one return value. - let ret = api_fn.rets.first().unwrap(); - ret.typ.generate_lua_get(1) - }; - let fn_name_c = name_to_c_literal(api_fn.fn_name); let ret_count = api_fn.rets.len() as i32; @@ -285,7 +252,7 @@ fn generate_code_for_api_fn(api_fn: ApiFn) -> proc_macro2::TokenStream { lua.call(last_non_empty+1, #ret_count); - let ret = Ok(#ret_expr); + let ret = LuaGetValue::get(lua, -1); lua.pop_last_n(#ret_count); ret } diff --git a/ewext/src/lib.rs b/ewext/src/lib.rs index cdfaca22..ceb635c3 100644 --- a/ewext/src/lib.rs +++ b/ewext/src/lib.rs @@ -6,7 +6,7 @@ use std::{ }; use addr_grabber::{grab_addrs, grabbed_fns, grabbed_globals}; -use eyre::bail; +use eyre::{bail, OptionExt}; use noita::{ntypes::Entity, pixel::NoitaPixelRun, ParticleWorldState}; use noita_api::lua::{lua_bindings::lua_State, LuaState, ValuesOnStack, LUA}; @@ -95,7 +95,8 @@ fn bench_fn(_lua: LuaState) -> eyre::Result<()> { let start = Instant::now(); let iters = 10000; for _ in 0..iters { - let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?; + let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())? + .ok_or_eyre("Entity not found")?; noita_api::raw::entity_set_transform(player, 0.0, Some(0.0), None, None, None)?; } let elapsed = start.elapsed(); @@ -113,12 +114,12 @@ fn bench_fn(_lua: LuaState) -> eyre::Result<()> { } fn test_fn(_lua: LuaState) -> eyre::Result<()> { - let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?; - let damage_model = noita_api::DamageModelComponent(noita_api::raw::entity_get_first_component( - player, - "DamageModelComponent".into(), - None, - )?); + let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())? + .ok_or_eyre("Entity not found")?; + let damage_model = noita_api::DamageModelComponent( + noita_api::raw::entity_get_first_component(player, "DamageModelComponent".into(), None)? + .ok_or_eyre("Could not find damage model")?, + ); let hp = damage_model.hp()?; noita_api::raw::game_print(