LibWasm: Add support for proposal 'tail-call'

This commit is contained in:
Ali Mohammad Pur 2025-09-22 11:38:44 +02:00 committed by Ali Mohammad Pur
parent d065171791
commit 6a6f747701
Notes: github-actions[bot] 2025-10-14 23:29:38 +00:00
9 changed files with 195 additions and 43 deletions

View file

@ -68,7 +68,7 @@ struct ConvertToRaw<double> {
do { \
if (trap_if_not(x, #x##sv __VA_OPT__(, ) __VA_ARGS__)) { \
dbgln_if(WASM_TRACE_DEBUG, "Trapped because {} failed, at line {}", #x, __LINE__); \
return true; \
return Outcome::Return; \
} \
} while (false)
@ -98,11 +98,7 @@ void BytecodeInterpreter::interpret(Configuration& configuration)
return interpret_impl<false, false, false>(configuration, expression);
}
enum class Outcome : u64 {
// 0..Constants::max_allowed_executed_instructions_per_call -> next IP.
Continue = Constants::max_allowed_executed_instructions_per_call + 1,
Return,
};
constexpr static u32 default_sources_and_destination = (to_underlying(Dispatch::RegisterOrStack::Stack) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 2) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 4));
template<u64 opcode>
struct InstructionHandler { };
@ -1071,7 +1067,7 @@ HANDLE_INSTRUCTION(synthetic_call_00)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1083,7 +1079,7 @@ HANDLE_INSTRUCTION(synthetic_call_01)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1095,7 +1091,7 @@ HANDLE_INSTRUCTION(synthetic_call_10)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1107,7 +1103,7 @@ HANDLE_INSTRUCTION(synthetic_call_11)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1119,7 +1115,7 @@ HANDLE_INSTRUCTION(synthetic_call_20)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1131,7 +1127,7 @@ HANDLE_INSTRUCTION(synthetic_call_21)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1143,7 +1139,7 @@ HANDLE_INSTRUCTION(synthetic_call_30)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1155,7 +1151,7 @@ HANDLE_INSTRUCTION(synthetic_call_31)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1321,11 +1317,31 @@ HANDLE_INSTRUCTION(call)
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "call({})", address.value());
if (interpreter.call_address(configuration, address))
if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
}
HANDLE_INSTRUCTION(return_call)
{
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
configuration.label_stack().shrink(configuration.frame().label_index() + 1, true);
dbgln_if(WASM_TRACE_DEBUG, "tail call({})", address.value());
switch (auto const outcome = interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::DirectTailCall)) {
default:
// Some IP we have to continue from.
current_ip_value = to_underlying(outcome) - 1;
addresses = { .sources_and_destination = default_sources_and_destination };
cc = configuration.frame().expression().compiled_instructions.dispatches.data();
[[fallthrough]];
case Outcome::Continue:
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
case Outcome::Return:
return Outcome::Return;
}
}
HANDLE_INSTRUCTION(call_indirect)
{
auto& args = instruction->arguments().get<Instruction::IndirectCallArgs>();
@ -1346,11 +1362,45 @@ HANDLE_INSTRUCTION(call_indirect)
TRAP_IN_LOOP_IF_NOT(type_actual.results() == type_expected.results());
dbgln_if(WASM_TRACE_DEBUG, "call_indirect({} -> {})", index, address.value());
if (interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::IndirectCall))
if (interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::IndirectCall) == Outcome::Return)
return Outcome::Return;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
}
HANDLE_INSTRUCTION(return_call_indirect)
{
auto& args = instruction->arguments().get<Instruction::IndirectCallArgs>();
auto table_address = configuration.frame().module().tables()[args.table.value()];
auto table_instance = configuration.store().get(table_address);
// bounds checked by verifier.
auto index = configuration.take_source(0, addresses.sources).to<i32>();
TRAP_IN_LOOP_IF_NOT(index >= 0);
TRAP_IN_LOOP_IF_NOT(static_cast<size_t>(index) < table_instance->elements().size());
auto& element = table_instance->elements()[index];
TRAP_IN_LOOP_IF_NOT(element.ref().has<Reference::Func>());
auto address = element.ref().get<Reference::Func>().address;
auto const& type_actual = configuration.store().get(address)->visit([](auto& f) -> decltype(auto) { return f.type(); });
auto const& type_expected = configuration.frame().module().types()[args.type.value()];
TRAP_IN_LOOP_IF_NOT(type_actual.parameters().size() == type_expected.parameters().size());
TRAP_IN_LOOP_IF_NOT(type_actual.results().size() == type_expected.results().size());
TRAP_IN_LOOP_IF_NOT(type_actual.parameters() == type_expected.parameters());
TRAP_IN_LOOP_IF_NOT(type_actual.results() == type_expected.results());
dbgln_if(WASM_TRACE_DEBUG, "tail call_indirect({} -> {})", index, address.value());
switch (auto const outcome = interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::IndirectTailCall)) {
default:
// Some IP we have to continue from.
current_ip_value = to_underlying(outcome) - 1;
addresses = { .sources_and_destination = default_sources_and_destination };
cc = configuration.frame().expression().compiled_instructions.dispatches.data();
[[fallthrough]];
case Outcome::Continue:
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
case Outcome::Return:
return Outcome::Return;
}
}
HANDLE_INSTRUCTION(i32_load)
{
if (interpreter.load_and_push<i32, i32>(configuration, *instruction, addresses))
@ -3633,10 +3683,9 @@ FLATTEN void BytecodeInterpreter::interpret_impl(Configuration& configuration, E
auto current_ip_value = configuration.ip();
u64 executed_instructions = 0;
constexpr static u32 default_sources_and_destination = (to_underlying(Dispatch::RegisterOrStack::Stack) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 2) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 4));
SourcesAndDestination addresses { .sources_and_destination = default_sources_and_destination };
auto const cc = expression.compiled_instructions.dispatches.data();
auto cc = expression.compiled_instructions.dispatches.data();
if constexpr (HaveDirectThreadingInfo) {
static_assert(HasCompiledList, "Direct threading requires a compiled instruction list");
@ -3678,6 +3727,8 @@ FLATTEN void BytecodeInterpreter::interpret_impl(Configuration& configuration, E
if (outcome == Outcome::Return) \
return; \
current_ip_value = to_underlying(outcome); \
if constexpr (Instructions::name == Instructions::return_call || Instructions::name == Instructions::return_call_indirect) \
cc = configuration.frame().expression().compiled_instructions.dispatches.data(); \
RUN_NEXT_INSTRUCTION(); \
}
@ -3869,14 +3920,14 @@ VectorType BytecodeInterpreter::pop_vector(Configuration& configuration, size_t
return bit_cast<VectorType>(configuration.take_source(source, addresses.sources).to<u128>());
}
bool BytecodeInterpreter::call_address(Configuration& configuration, FunctionAddress address, CallAddressSource source)
Outcome BytecodeInterpreter::call_address(Configuration& configuration, FunctionAddress address, CallAddressSource source)
{
TRAP_IF_NOT(m_stack_info.size_free() >= Constants::minimum_stack_space_to_keep_free, "{}: {}", Constants::stack_exhaustion_message);
auto instance = configuration.store().get(address);
FunctionType const* type { nullptr };
instance->visit([&](auto const& function) { type = &function.type(); });
if (source == CallAddressSource::IndirectCall) {
if (source == CallAddressSource::IndirectCall || source == CallAddressSource::IndirectTailCall) {
TRAP_IF_NOT(type->parameters().size() <= configuration.value_stack().size());
}
Vector<Value> args;
@ -3890,16 +3941,34 @@ bool BytecodeInterpreter::call_address(Configuration& configuration, FunctionAdd
}
Result result { Trap::from_string("") };
if (instance->has<WasmFunction>()) {
CallFrameHandle handle { *this, configuration };
result = configuration.call(*this, address, move(args));
Outcome final_outcome = Outcome::Continue;
if (source == CallAddressSource::DirectTailCall || source == CallAddressSource::IndirectTailCall) {
auto prep_outcome = configuration.prepare_call(address, args, true);
if (prep_outcome.is_error()) {
m_trap = prep_outcome.release_error();
return Outcome::Return;
}
final_outcome = Outcome::Return; // At this point we can only ever return (unless we succeed in tail-calling).
if (prep_outcome.value().has_value()) {
result = prep_outcome.value()->function()(configuration, args);
} else {
configuration.ip() = 0;
return static_cast<Outcome>(0); // Continue from IP 0 in the new frame.
}
} else {
result = configuration.call(*this, address, move(args));
if (instance->has<WasmFunction>()) {
CallFrameHandle handle { *this, configuration };
result = configuration.call(*this, address, move(args));
} else {
result = configuration.call(*this, address, move(args));
}
}
if (result.is_trap()) {
m_trap = move(result.trap());
return true;
return Outcome::Return;
}
if (!result.values().is_empty()) {
@ -3908,7 +3977,7 @@ bool BytecodeInterpreter::call_address(Configuration& configuration, FunctionAdd
configuration.value_stack().unchecked_append(entry);
}
return false;
return final_outcome;
}
template<typename PopTypeLHS, typename PushType, typename Operator, typename PopTypeRHS, typename... Args>