Wip optional arguments

This commit is contained in:
IQuant 2024-11-25 02:33:02 +03:00
parent 6102c30b2e
commit bac266e456
5 changed files with 157 additions and 42 deletions

View file

@ -14,6 +14,8 @@ enum Typ {
UInt, UInt,
#[serde(rename = "float")] #[serde(rename = "float")]
Float, Float,
#[serde(rename = "double")]
Double,
#[serde(rename = "bool")] #[serde(rename = "bool")]
Bool, Bool,
#[serde(rename = "std::string")] #[serde(rename = "std::string")]
@ -29,13 +31,25 @@ impl Typ {
match self { match self {
Typ::Int => quote!(i32), Typ::Int => quote!(i32),
Typ::UInt => quote!(u32), Typ::UInt => quote!(u32),
Typ::Float => quote!(f32), Typ::Float => quote!(f64),
Typ::Double => quote!(f64),
Typ::Bool => quote!(bool), Typ::Bool => quote!(bool),
Typ::StdString => todo!(), Typ::StdString => todo!(),
Typ::Vec2 => todo!(), Typ::Vec2 => todo!(),
Typ::Other => todo!(), Typ::Other => todo!(),
} }
} }
fn as_lua_type(&self) -> &'static str {
match self {
Typ::Int | Typ::UInt => "integer",
Typ::Float | Typ::Double => "number",
Typ::Bool => "bool",
Typ::StdString => todo!(),
Typ::Vec2 => todo!(),
Typ::Other => todo!(),
}
}
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -123,7 +137,7 @@ struct Component {
struct FnArg { struct FnArg {
name: String, name: String,
typ: Typ2, typ: Typ2,
// default: Option<String>, default: Option<String>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -163,17 +177,22 @@ fn generate_code_for_component(com: Component) -> proc_macro2::TokenStream {
let component_name = format_ident!("{}", com.name); let component_name = format_ident!("{}", com.name);
let impls = com.fields.iter().filter_map(|field| { let impls = com.fields.iter().filter_map(|field| {
let field_name = format_ident!("{}", convert_field_name(&field.field)); let field_name_s = convert_field_name(&field.field);
let field_name = format_ident!("{}", field_name_s);
let field_doc = &field.desc; let field_doc = &field.desc;
let set_method_name = format_ident!("set_{}", field_name);
match field.typ { match field.typ {
Typ::Int | Typ::UInt | Typ::Float | Typ::Bool => { Typ::Int | Typ::UInt | Typ::Float | Typ::Double | Typ::Bool => {
let field_type = field.typ.as_rust_type(); let field_type = field.typ.as_rust_type();
let set_method_name = format_ident!("set_{}", field_name); let getter_fn = format_ident!("component_get_value_{}", field.typ.as_lua_type());
Some(quote! { Some(quote! {
#[doc = #field_doc] #[doc = #field_doc]
fn #field_name(self) -> #field_type { todo!() } pub fn #field_name(self) -> eyre::Result<#field_type> {
// This trasmute is used to reinterpret i32 as u32 in one case.
unsafe { std::mem::transmute(raw::#getter_fn(self.0, #field_name_s)) }
}
#[doc = #field_doc] #[doc = #field_doc]
fn #set_method_name(self, value: #field_type) { todo!() } pub fn #set_method_name(self, value: #field_type) -> eyre::Result<()> { todo!() }
}) })
} }
_ => None, _ => None,
@ -182,7 +201,7 @@ fn generate_code_for_component(com: Component) -> proc_macro2::TokenStream {
quote! { quote! {
#[derive(Clone, Copy, PartialEq, Eq)] #[derive(Clone, Copy, PartialEq, Eq)]
struct #component_name(u32); pub struct #component_name(pub(crate) ComponentID);
impl #component_name { impl #component_name {
#(#impls)* #(#impls)*
@ -197,14 +216,32 @@ fn generate_code_for_api_fn(api_fn: ApiFn) -> proc_macro2::TokenStream {
let args = api_fn.args.iter().map(|arg| { let args = api_fn.args.iter().map(|arg| {
let arg_name = format_ident!("{}", arg.name); let arg_name = format_ident!("{}", arg.name);
let arg_type = arg.typ.as_rust_type(); let arg_type = arg.typ.as_rust_type();
quote! { let optional = arg.default.is_some();
#arg_name: #arg_type if optional {
quote! {
#arg_name: Option<#arg_type>
}
} else {
quote! {
#arg_name: #arg_type
}
} }
}); });
let put_args = api_fn.args.iter().map(|arg| { let put_args = api_fn.args.iter().map(|arg| {
let optional = arg.default.is_some();
let arg_name = format_ident!("{}", arg.name); let arg_name = format_ident!("{}", arg.name);
arg.typ.generate_lua_push(arg_name) let arg_push = arg.typ.generate_lua_push(arg_name.clone());
if optional {
quote! {
match #arg_name {
Some(#arg_name) => #arg_push,
None => lua.push_nil(),
}
}
} else {
arg_push
}
}); });
let ret_type = if api_fn.rets.is_empty() { let ret_type = if api_fn.rets.is_empty() {
@ -242,7 +279,9 @@ fn generate_code_for_api_fn(api_fn: ApiFn) -> proc_macro2::TokenStream {
lua.call(#arg_count, #ret_count); lua.call(#arg_count, #ret_count);
Ok(#ret_expr) let ret = Ok(#ret_expr);
lua.pop_last_n(#ret_count);
ret
} }
} }
} }

View file

@ -1,3 +1,4 @@
use crate::noita::api::DamageModelComponent;
use std::{ use std::{
arch::asm, arch::asm,
cell::{LazyCell, RefCell}, cell::{LazyCell, RefCell},
@ -14,9 +15,9 @@ use noita::{ntypes::Entity, pixel::NoitaPixelRun, ParticleWorldState};
use noita_api_macro::add_lua_fn; use noita_api_macro::add_lua_fn;
mod lua_bindings; mod lua_bindings;
mod lua_state; pub mod lua_state;
mod noita; pub mod noita;
mod addr_grabber; mod addr_grabber;
@ -100,12 +101,12 @@ fn on_world_initialized(lua: LuaState) {
grab_addrs(lua); grab_addrs(lua);
} }
fn test_fn(_lua: LuaState) -> eyre::Result<()> { fn bench_fn(_lua: LuaState) -> eyre::Result<()> {
let start = Instant::now(); let start = Instant::now();
let iters = 10000; let iters = 10000;
for _ in 0..iters { for _ in 0..iters {
let player = noita::api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?; let player = noita::api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?;
noita::api::raw::entity_set_transform(player, 0.0, 0.0, 0.0, 1.0, 1.0)?; noita::api::raw::entity_set_transform(player, 0.0, Some(0.0), None, None, None)?;
} }
let elapsed = start.elapsed(); let elapsed = start.elapsed();
@ -121,6 +122,24 @@ fn test_fn(_lua: LuaState) -> eyre::Result<()> {
Ok(()) Ok(())
} }
fn test_fn(_lua: LuaState) -> eyre::Result<()> {
let player = noita::api::raw::entity_get_closest_with_tag(0.0, 0.0, "player_unit".into())?;
let damage_model = DamageModelComponent(noita::api::raw::entity_get_first_component(
player,
"DamageModelComponent".into(),
None,
)?);
let hp = damage_model.hp()?;
noita::api::raw::game_print(
format!("Component: {:?}, Hp: {}", damage_model.0, hp * 25.0).into(),
)?;
// noita::api::raw::entity_set_transform(player, 0.0, 0.0, 0.0, 1.0, 1.0)?;
Ok(())
}
/// # Safety /// # Safety
/// ///
/// Only gets called by lua when loading a module. /// Only gets called by lua when loading a module.
@ -135,6 +154,7 @@ pub unsafe extern "C" fn luaopen_ewext0(lua: *mut lua_State) -> c_int {
add_lua_fn!(make_ephemerial); add_lua_fn!(make_ephemerial);
add_lua_fn!(on_world_initialized); add_lua_fn!(on_world_initialized);
add_lua_fn!(test_fn); add_lua_fn!(test_fn);
add_lua_fn!(bench_fn);
} }
println!("Initializing ewext - Ok"); println!("Initializing ewext - Ok");
1 1

View file

@ -16,24 +16,23 @@ thread_local! {
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub(crate) struct LuaState { pub struct LuaState {
lua: *mut lua_State, lua: *mut lua_State,
} }
#[expect(dead_code)]
impl LuaState { impl LuaState {
pub(crate) fn new(lua: *mut lua_State) -> Self { pub fn new(lua: *mut lua_State) -> Self {
Self { lua } Self { lua }
} }
/// Returns a lua state that is considered "current". Usually set when we get called from noita. /// Returns a lua state that is considered "current". Usually set when we get called from noita.
pub(crate) fn current() -> eyre::Result<Self> { pub fn current() -> eyre::Result<Self> {
CURRENT_LUA_STATE CURRENT_LUA_STATE
.get() .get()
.ok_or_eyre("No current lua state available") .ok_or_eyre("No current lua state available")
} }
pub(crate) fn make_current(self) { pub fn make_current(self) {
CURRENT_LUA_STATE.set(Some(self)); CURRENT_LUA_STATE.set(Some(self));
} }
@ -41,19 +40,19 @@ impl LuaState {
self.lua self.lua
} }
pub(crate) fn to_integer(&self, index: i32) -> isize { pub fn to_integer(&self, index: i32) -> isize {
unsafe { LUA.lua_tointeger(self.lua, index) } unsafe { LUA.lua_tointeger(self.lua, index) }
} }
pub(crate) fn to_number(&self, index: i32) -> f64 { pub fn to_number(&self, index: i32) -> f64 {
unsafe { LUA.lua_tonumber(self.lua, index) } unsafe { LUA.lua_tonumber(self.lua, index) }
} }
pub(crate) fn to_bool(&self, index: i32) -> bool { pub fn to_bool(&self, index: i32) -> bool {
unsafe { LUA.lua_toboolean(self.lua, index) > 0 } unsafe { LUA.lua_toboolean(self.lua, index) > 0 }
} }
pub(crate) fn to_string(&self, index: i32) -> eyre::Result<String> { pub fn to_string(&self, index: i32) -> eyre::Result<String> {
let mut size = 0; let mut size = 0;
let buf = unsafe { LUA.lua_tolstring(self.lua, index, &mut size) }; let buf = unsafe { LUA.lua_tolstring(self.lua, index, &mut size) };
if buf.is_null() { if buf.is_null() {
@ -65,39 +64,46 @@ impl LuaState {
.context("Attempting to get lua string, expecting it to be utf-8")?) .context("Attempting to get lua string, expecting it to be utf-8")?)
} }
pub(crate) fn to_cfunction(&self, index: i32) -> lua_CFunction { pub fn to_cfunction(&self, index: i32) -> lua_CFunction {
unsafe { LUA.lua_tocfunction(self.lua, index) } unsafe { LUA.lua_tocfunction(self.lua, index) }
} }
pub(crate) fn push_number(&self, val: f64) { pub fn push_number(&self, val: f64) {
unsafe { LUA.lua_pushnumber(self.lua, val) }; unsafe { LUA.lua_pushnumber(self.lua, val) };
} }
pub(crate) fn push_integer(&self, val: isize) { pub fn push_integer(&self, val: isize) {
unsafe { LUA.lua_pushinteger(self.lua, val) }; unsafe { LUA.lua_pushinteger(self.lua, val) };
} }
pub(crate) fn push_bool(&self, val: bool) { pub fn push_bool(&self, val: bool) {
unsafe { LUA.lua_pushboolean(self.lua, val as i32) }; unsafe { LUA.lua_pushboolean(self.lua, val as i32) };
} }
pub(crate) fn push_string(&self, s: &str) { pub fn push_string(&self, s: &str) {
unsafe { unsafe {
LUA.lua_pushlstring(self.lua, s.as_bytes().as_ptr() as *const c_char, s.len()); LUA.lua_pushlstring(self.lua, s.as_bytes().as_ptr() as *const c_char, s.len());
} }
} }
pub(crate) fn call(&self, nargs: i32, nresults: i32) { pub fn push_nil(&self) {
unsafe { LUA.lua_pushnil(self.lua) }
}
pub fn call(&self, nargs: i32, nresults: i32) {
unsafe { LUA.lua_call(self.lua, nargs, nresults) }; unsafe { LUA.lua_call(self.lua, nargs, nresults) };
} }
pub(crate) fn get_global(&self, name: &CStr) { pub fn get_global(&self, name: &CStr) {
unsafe { LUA.lua_getfield(self.lua, LUA_GLOBALSINDEX, name.as_ptr()) }; unsafe { LUA.lua_getfield(self.lua, LUA_GLOBALSINDEX, name.as_ptr()) };
} }
pub(crate) fn pop_last(&self) { pub fn pop_last(&self) {
unsafe { LUA.lua_settop(self.lua, -2) }; unsafe { LUA.lua_settop(self.lua, -2) };
} }
pub fn pop_last_n(&self, n: i32) {
unsafe { LUA.lua_settop(self.lua, -1 - (n)) };
}
/// Raise an error with message `s` /// Raise an error with message `s`
/// ///

View file

@ -3,23 +3,72 @@ use std::{ffi::c_void, mem};
pub(crate) mod ntypes; pub(crate) mod ntypes;
pub(crate) mod pixel; pub(crate) mod pixel;
pub(crate) mod api { pub mod api {
pub(crate) struct EntityID(isize); #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ComponentID(isize); pub struct EntityID(isize);
pub(crate) struct Obj(usize); #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ComponentID(isize);
pub(crate) struct Color(u32); pub struct Obj(usize);
pub struct Color(u32);
noita_api_macro::generate_components!(); noita_api_macro::generate_components!();
pub(crate) mod raw { pub mod raw {
use super::{Color, ComponentID, EntityID, Obj}; use super::{Color, ComponentID, EntityID, Obj};
use std::borrow::Cow; use std::borrow::Cow;
use crate::LuaState; use crate::LuaState;
noita_api_macro::generate_api!(); noita_api_macro::generate_api!();
fn component_get_value_base(
component: ComponentID,
field: &str,
expected_results: i32,
) -> eyre::Result<()> {
let lua = LuaState::current()?;
lua.get_global(c"ComponentGetValue2");
lua.push_integer(component.0);
lua.push_string(field);
lua.call(2, expected_results);
Ok(())
}
pub(crate) fn component_get_value_number(
component: ComponentID,
field: &str,
) -> eyre::Result<f64> {
component_get_value_base(component, field, 1)?;
let lua = LuaState::current()?;
let ret = lua.to_number(1);
lua.pop_last();
Ok(ret)
}
pub(crate) fn component_get_value_integer(
component: ComponentID,
field: &str,
) -> eyre::Result<i32> {
component_get_value_base(component, field, 1)?;
let lua = LuaState::current()?;
let ret = lua.to_integer(1);
lua.pop_last();
Ok(ret as i32)
}
pub(crate) fn component_get_value_bool(
component: ComponentID,
field: &str,
) -> eyre::Result<bool> {
component_get_value_base(component, field, 1)?;
let lua = LuaState::current()?;
let ret = lua.to_bool(1);
lua.pop_last();
Ok(ret)
}
} }
} }

View file

@ -39,7 +39,7 @@ local function fw_button(label)
return imgui.Button(label, imgui.GetWindowWidth() - 15, 20) return imgui.Button(label, imgui.GetWindowWidth() - 15, 20)
end end
local function test_fn_lua() local function bench_fn_lua()
local start = GameGetRealWorldTimeSinceStarted() local start = GameGetRealWorldTimeSinceStarted()
for i=1,10000 do for i=1,10000 do
local player = EntityGetClosestWithTag(0, 0, "player_unit") local player = EntityGetClosestWithTag(0, 0, "player_unit")
@ -54,8 +54,9 @@ function module.on_draw_debug_window(imgui)
if fw_button("test_fn") then if fw_button("test_fn") then
ewext.test_fn() ewext.test_fn()
end end
if fw_button("test_fn_lua") then if fw_button("bench") then
test_fn_lua() ewext.bench_fn()
bench_fn_lua()
end end
end end
end end