A better way to return things

This commit is contained in:
IQuant 2024-11-25 19:35:38 +03:00
parent d62fb39a93
commit f10815e903
4 changed files with 112 additions and 49 deletions

View file

@ -1,10 +1,12 @@
use std::num::NonZero;
pub mod lua; pub mod lua;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EntityID(pub isize); pub struct EntityID(pub NonZero<isize>);
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ComponentID(pub isize); pub struct ComponentID(pub NonZero<isize>);
pub struct Obj(pub usize); pub struct Obj(pub usize);
@ -14,6 +16,7 @@ noita_api_macro::generate_components!();
pub mod raw { pub mod raw {
use super::{Color, ComponentID, EntityID, Obj}; use super::{Color, ComponentID, EntityID, Obj};
use crate::lua::LuaGetValue;
use crate::lua::LuaPutValue; use crate::lua::LuaPutValue;
use std::borrow::Cow; use std::borrow::Cow;
@ -28,7 +31,7 @@ pub mod raw {
) -> eyre::Result<()> { ) -> eyre::Result<()> {
let lua = LuaState::current()?; let lua = LuaState::current()?;
lua.get_global(c"ComponentGetValue2"); lua.get_global(c"ComponentGetValue2");
lua.push_integer(component.0); lua.push_integer(component.0.into());
lua.push_string(field); lua.push_string(field);
lua.call(2, expected_results); lua.call(2, expected_results);
Ok(()) Ok(())

View file

@ -207,13 +207,13 @@ impl LuaPutValue for str {
impl LuaPutValue for EntityID { impl LuaPutValue for EntityID {
fn put(&self, lua: LuaState) { fn put(&self, lua: LuaState) {
self.0.put(lua); isize::from(self.0).put(lua);
} }
} }
impl LuaPutValue for ComponentID { impl LuaPutValue for ComponentID {
fn put(&self, lua: LuaState) { fn put(&self, lua: LuaState) {
self.0.put(lua); isize::from(self.0).put(lua);
} }
} }
@ -244,3 +244,95 @@ impl<T: LuaPutValue> LuaPutValue for Option<T> {
} }
} }
} }
/// Trait for arguments that can be retrieved from the lua stack.
pub(crate) trait LuaGetValue {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self>
where
Self: Sized;
fn size() -> i32 {
1
}
}
impl LuaGetValue for i32 {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
Ok(lua.to_integer(index) as Self)
}
}
impl LuaGetValue for isize {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
Ok(lua.to_integer(index) as Self)
}
}
impl LuaGetValue for u32 {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
Ok(unsafe { mem::transmute(lua.to_integer(index) as i32) })
}
}
impl LuaGetValue for f32 {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
Ok(lua.to_number(index) as f32)
}
}
impl LuaGetValue for f64 {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
Ok(lua.to_number(index))
}
}
impl LuaGetValue for bool {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
Ok(lua.to_bool(index))
}
}
impl LuaGetValue for Option<EntityID> {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
let ent = lua.to_integer(index);
Ok(if ent == 0 {
None
} else {
Some(EntityID(ent.try_into().unwrap()))
})
}
}
impl LuaGetValue for Option<ComponentID> {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
let com = lua.to_integer(index);
Ok(if com == 0 {
None
} else {
Some(ComponentID(com.try_into().unwrap()))
})
}
}
impl LuaGetValue for Cow<'static, str> {
fn get(lua: LuaState, index: i32) -> eyre::Result<Self> {
Ok(lua.to_string(index)?.into())
}
}
impl LuaGetValue for () {
fn get(_lua: LuaState, _index: i32) -> eyre::Result<Self> {
Ok(())
}
}
impl LuaGetValue for Obj {
fn get(_lua: LuaState, _index: i32) -> eyre::Result<Self> {
todo!()
}
}
impl LuaGetValue for Color {
fn get(_lua: LuaState, _index: i32) -> eyre::Result<Self> {
todo!()
}
}

View file

@ -2,7 +2,6 @@ use std::ffi::CString;
use heck::ToSnekCase; use heck::ToSnekCase;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use serde::Deserialize; use serde::Deserialize;
@ -89,35 +88,11 @@ impl Typ2 {
fn as_rust_type_return(&self) -> proc_macro2::TokenStream { fn as_rust_type_return(&self) -> proc_macro2::TokenStream {
match self { match self {
Typ2::String => quote! {Cow<'static, str>}, Typ2::String => quote! {Cow<'static, str>},
Typ2::EntityID => quote! {Option<EntityID>},
Typ2::ComponentID => quote!(Option<ComponentID>),
_ => self.as_rust_type(), _ => self.as_rust_type(),
} }
} }
fn generate_lua_push(&self, arg_name: Ident) -> proc_macro2::TokenStream {
match self {
Typ2::Int => quote! {lua.push_integer(#arg_name as isize)},
Typ2::Number => quote! {lua.push_number(#arg_name)},
Typ2::String => quote! {lua.push_string(&#arg_name)},
Typ2::Bool => quote! {lua.push_bool(#arg_name)},
Typ2::EntityID => quote! {lua.push_integer(#arg_name.0 as isize)},
Typ2::ComponentID => quote! {lua.push_integer(#arg_name.0 as isize)},
Typ2::Obj => quote! { todo!() },
Typ2::Color => quote! { todo!() },
}
}
fn generate_lua_get(&self, index: i32) -> proc_macro2::TokenStream {
match self {
Typ2::Int => quote! {lua.to_integer(#index) as i32},
Typ2::Number => quote! {lua.to_number(#index)},
Typ2::String => quote! { lua.to_string(#index)?.into() },
Typ2::Bool => quote! {lua.to_bool(#index)},
Typ2::EntityID => quote! {EntityID(lua.to_integer(#index))},
Typ2::ComponentID => quote! {ComponentID(lua.to_integer(#index))},
Typ2::Obj => quote! { todo!() },
Typ2::Color => quote! { todo!() },
}
}
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -260,14 +235,6 @@ fn generate_code_for_api_fn(api_fn: ApiFn) -> proc_macro2::TokenStream {
// } // }
}; };
let ret_expr = if api_fn.rets.is_empty() {
quote! { () }
} else {
// TODO support for more than one return value.
let ret = api_fn.rets.first().unwrap();
ret.typ.generate_lua_get(1)
};
let fn_name_c = name_to_c_literal(api_fn.fn_name); let fn_name_c = name_to_c_literal(api_fn.fn_name);
let ret_count = api_fn.rets.len() as i32; let ret_count = api_fn.rets.len() as i32;
@ -285,7 +252,7 @@ fn generate_code_for_api_fn(api_fn: ApiFn) -> proc_macro2::TokenStream {
lua.call(last_non_empty+1, #ret_count); lua.call(last_non_empty+1, #ret_count);
let ret = Ok(#ret_expr); let ret = LuaGetValue::get(lua, -1);
lua.pop_last_n(#ret_count); lua.pop_last_n(#ret_count);
ret ret
} }

View file

@ -6,7 +6,7 @@ use std::{
}; };
use addr_grabber::{grab_addrs, grabbed_fns, grabbed_globals}; use addr_grabber::{grab_addrs, grabbed_fns, grabbed_globals};
use eyre::bail; use eyre::{bail, OptionExt};
use noita::{ntypes::Entity, pixel::NoitaPixelRun, ParticleWorldState}; use noita::{ntypes::Entity, pixel::NoitaPixelRun, ParticleWorldState};
use noita_api::lua::{lua_bindings::lua_State, LuaState, ValuesOnStack, LUA}; use noita_api::lua::{lua_bindings::lua_State, LuaState, ValuesOnStack, LUA};
@ -95,7 +95,8 @@ fn bench_fn(_lua: LuaState) -> eyre::Result<()> {
let start = Instant::now(); let start = Instant::now();
let iters = 10000; let iters = 10000;
for _ in 0..iters { for _ in 0..iters {
let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?; let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?
.ok_or_eyre("Entity not found")?;
noita_api::raw::entity_set_transform(player, 0.0, Some(0.0), None, None, None)?; noita_api::raw::entity_set_transform(player, 0.0, Some(0.0), None, None, None)?;
} }
let elapsed = start.elapsed(); let elapsed = start.elapsed();
@ -113,12 +114,12 @@ fn bench_fn(_lua: LuaState) -> eyre::Result<()> {
} }
fn test_fn(_lua: LuaState) -> eyre::Result<()> { fn test_fn(_lua: LuaState) -> eyre::Result<()> {
let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?; let player = noita_api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?
let damage_model = noita_api::DamageModelComponent(noita_api::raw::entity_get_first_component( .ok_or_eyre("Entity not found")?;
player, let damage_model = noita_api::DamageModelComponent(
"DamageModelComponent".into(), noita_api::raw::entity_get_first_component(player, "DamageModelComponent".into(), None)?
None, .ok_or_eyre("Could not find damage model")?,
)?); );
let hp = damage_model.hp()?; let hp = damage_model.hp()?;
noita_api::raw::game_print( noita_api::raw::game_print(