LibWasm: Add support for proposal 'tail-call'

This commit is contained in:
Ali Mohammad Pur 2025-09-22 11:38:44 +02:00 committed by Ali Mohammad Pur
parent d065171791
commit 6a6f747701
Notes: github-actions[bot] 2025-10-14 23:29:38 +00:00
9 changed files with 195 additions and 43 deletions

View file

@ -68,7 +68,7 @@ struct ConvertToRaw<double> {
do { \ do { \
if (trap_if_not(x, #x##sv __VA_OPT__(, ) __VA_ARGS__)) { \ if (trap_if_not(x, #x##sv __VA_OPT__(, ) __VA_ARGS__)) { \
dbgln_if(WASM_TRACE_DEBUG, "Trapped because {} failed, at line {}", #x, __LINE__); \ dbgln_if(WASM_TRACE_DEBUG, "Trapped because {} failed, at line {}", #x, __LINE__); \
return true; \ return Outcome::Return; \
} \ } \
} while (false) } while (false)
@ -98,11 +98,7 @@ void BytecodeInterpreter::interpret(Configuration& configuration)
return interpret_impl<false, false, false>(configuration, expression); return interpret_impl<false, false, false>(configuration, expression);
} }
enum class Outcome : u64 { constexpr static u32 default_sources_and_destination = (to_underlying(Dispatch::RegisterOrStack::Stack) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 2) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 4));
// 0..Constants::max_allowed_executed_instructions_per_call -> next IP.
Continue = Constants::max_allowed_executed_instructions_per_call + 1,
Return,
};
template<u64 opcode> template<u64 opcode>
struct InstructionHandler { }; struct InstructionHandler { };
@ -1071,7 +1067,7 @@ HANDLE_INSTRUCTION(synthetic_call_00)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1083,7 +1079,7 @@ HANDLE_INSTRUCTION(synthetic_call_01)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1095,7 +1091,7 @@ HANDLE_INSTRUCTION(synthetic_call_10)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1107,7 +1103,7 @@ HANDLE_INSTRUCTION(synthetic_call_11)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1119,7 +1115,7 @@ HANDLE_INSTRUCTION(synthetic_call_20)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1131,7 +1127,7 @@ HANDLE_INSTRUCTION(synthetic_call_21)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1143,7 +1139,7 @@ HANDLE_INSTRUCTION(synthetic_call_30)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1155,7 +1151,7 @@ HANDLE_INSTRUCTION(synthetic_call_31)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value()); dbgln_if(WASM_TRACE_DEBUG, "[{}] call(#{} -> {})", current_ip_value, index.value(), address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
configuration.regs = regs_copy; configuration.regs = regs_copy;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
@ -1321,11 +1317,31 @@ HANDLE_INSTRUCTION(call)
auto index = instruction->arguments().get<FunctionIndex>(); auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()]; auto address = configuration.frame().module().functions()[index.value()];
dbgln_if(WASM_TRACE_DEBUG, "call({})", address.value()); dbgln_if(WASM_TRACE_DEBUG, "call({})", address.value());
if (interpreter.call_address(configuration, address)) if (interpreter.call_address(configuration, address) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
} }
HANDLE_INSTRUCTION(return_call)
{
auto index = instruction->arguments().get<FunctionIndex>();
auto address = configuration.frame().module().functions()[index.value()];
configuration.label_stack().shrink(configuration.frame().label_index() + 1, true);
dbgln_if(WASM_TRACE_DEBUG, "tail call({})", address.value());
switch (auto const outcome = interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::DirectTailCall)) {
default:
// Some IP we have to continue from.
current_ip_value = to_underlying(outcome) - 1;
addresses = { .sources_and_destination = default_sources_and_destination };
cc = configuration.frame().expression().compiled_instructions.dispatches.data();
[[fallthrough]];
case Outcome::Continue:
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
case Outcome::Return:
return Outcome::Return;
}
}
HANDLE_INSTRUCTION(call_indirect) HANDLE_INSTRUCTION(call_indirect)
{ {
auto& args = instruction->arguments().get<Instruction::IndirectCallArgs>(); auto& args = instruction->arguments().get<Instruction::IndirectCallArgs>();
@ -1346,11 +1362,45 @@ HANDLE_INSTRUCTION(call_indirect)
TRAP_IN_LOOP_IF_NOT(type_actual.results() == type_expected.results()); TRAP_IN_LOOP_IF_NOT(type_actual.results() == type_expected.results());
dbgln_if(WASM_TRACE_DEBUG, "call_indirect({} -> {})", index, address.value()); dbgln_if(WASM_TRACE_DEBUG, "call_indirect({} -> {})", index, address.value());
if (interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::IndirectCall)) if (interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::IndirectCall) == Outcome::Return)
return Outcome::Return; return Outcome::Return;
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY)); TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
} }
HANDLE_INSTRUCTION(return_call_indirect)
{
auto& args = instruction->arguments().get<Instruction::IndirectCallArgs>();
auto table_address = configuration.frame().module().tables()[args.table.value()];
auto table_instance = configuration.store().get(table_address);
// bounds checked by verifier.
auto index = configuration.take_source(0, addresses.sources).to<i32>();
TRAP_IN_LOOP_IF_NOT(index >= 0);
TRAP_IN_LOOP_IF_NOT(static_cast<size_t>(index) < table_instance->elements().size());
auto& element = table_instance->elements()[index];
TRAP_IN_LOOP_IF_NOT(element.ref().has<Reference::Func>());
auto address = element.ref().get<Reference::Func>().address;
auto const& type_actual = configuration.store().get(address)->visit([](auto& f) -> decltype(auto) { return f.type(); });
auto const& type_expected = configuration.frame().module().types()[args.type.value()];
TRAP_IN_LOOP_IF_NOT(type_actual.parameters().size() == type_expected.parameters().size());
TRAP_IN_LOOP_IF_NOT(type_actual.results().size() == type_expected.results().size());
TRAP_IN_LOOP_IF_NOT(type_actual.parameters() == type_expected.parameters());
TRAP_IN_LOOP_IF_NOT(type_actual.results() == type_expected.results());
dbgln_if(WASM_TRACE_DEBUG, "tail call_indirect({} -> {})", index, address.value());
switch (auto const outcome = interpreter.call_address(configuration, address, BytecodeInterpreter::CallAddressSource::IndirectTailCall)) {
default:
// Some IP we have to continue from.
current_ip_value = to_underlying(outcome) - 1;
addresses = { .sources_and_destination = default_sources_and_destination };
cc = configuration.frame().expression().compiled_instructions.dispatches.data();
[[fallthrough]];
case Outcome::Continue:
TAILCALL return continue_(HANDLER_PARAMS(DECOMPOSE_PARAMS_NAME_ONLY));
case Outcome::Return:
return Outcome::Return;
}
}
HANDLE_INSTRUCTION(i32_load) HANDLE_INSTRUCTION(i32_load)
{ {
if (interpreter.load_and_push<i32, i32>(configuration, *instruction, addresses)) if (interpreter.load_and_push<i32, i32>(configuration, *instruction, addresses))
@ -3633,10 +3683,9 @@ FLATTEN void BytecodeInterpreter::interpret_impl(Configuration& configuration, E
auto current_ip_value = configuration.ip(); auto current_ip_value = configuration.ip();
u64 executed_instructions = 0; u64 executed_instructions = 0;
constexpr static u32 default_sources_and_destination = (to_underlying(Dispatch::RegisterOrStack::Stack) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 2) | (to_underlying(Dispatch::RegisterOrStack::Stack) << 4));
SourcesAndDestination addresses { .sources_and_destination = default_sources_and_destination }; SourcesAndDestination addresses { .sources_and_destination = default_sources_and_destination };
auto const cc = expression.compiled_instructions.dispatches.data(); auto cc = expression.compiled_instructions.dispatches.data();
if constexpr (HaveDirectThreadingInfo) { if constexpr (HaveDirectThreadingInfo) {
static_assert(HasCompiledList, "Direct threading requires a compiled instruction list"); static_assert(HasCompiledList, "Direct threading requires a compiled instruction list");
@ -3678,6 +3727,8 @@ FLATTEN void BytecodeInterpreter::interpret_impl(Configuration& configuration, E
if (outcome == Outcome::Return) \ if (outcome == Outcome::Return) \
return; \ return; \
current_ip_value = to_underlying(outcome); \ current_ip_value = to_underlying(outcome); \
if constexpr (Instructions::name == Instructions::return_call || Instructions::name == Instructions::return_call_indirect) \
cc = configuration.frame().expression().compiled_instructions.dispatches.data(); \
RUN_NEXT_INSTRUCTION(); \ RUN_NEXT_INSTRUCTION(); \
} }
@ -3869,14 +3920,14 @@ VectorType BytecodeInterpreter::pop_vector(Configuration& configuration, size_t
return bit_cast<VectorType>(configuration.take_source(source, addresses.sources).to<u128>()); return bit_cast<VectorType>(configuration.take_source(source, addresses.sources).to<u128>());
} }
bool BytecodeInterpreter::call_address(Configuration& configuration, FunctionAddress address, CallAddressSource source) Outcome BytecodeInterpreter::call_address(Configuration& configuration, FunctionAddress address, CallAddressSource source)
{ {
TRAP_IF_NOT(m_stack_info.size_free() >= Constants::minimum_stack_space_to_keep_free, "{}: {}", Constants::stack_exhaustion_message); TRAP_IF_NOT(m_stack_info.size_free() >= Constants::minimum_stack_space_to_keep_free, "{}: {}", Constants::stack_exhaustion_message);
auto instance = configuration.store().get(address); auto instance = configuration.store().get(address);
FunctionType const* type { nullptr }; FunctionType const* type { nullptr };
instance->visit([&](auto const& function) { type = &function.type(); }); instance->visit([&](auto const& function) { type = &function.type(); });
if (source == CallAddressSource::IndirectCall) { if (source == CallAddressSource::IndirectCall || source == CallAddressSource::IndirectTailCall) {
TRAP_IF_NOT(type->parameters().size() <= configuration.value_stack().size()); TRAP_IF_NOT(type->parameters().size() <= configuration.value_stack().size());
} }
Vector<Value> args; Vector<Value> args;
@ -3890,16 +3941,34 @@ bool BytecodeInterpreter::call_address(Configuration& configuration, FunctionAdd
} }
Result result { Trap::from_string("") }; Result result { Trap::from_string("") };
Outcome final_outcome = Outcome::Continue;
if (source == CallAddressSource::DirectTailCall || source == CallAddressSource::IndirectTailCall) {
auto prep_outcome = configuration.prepare_call(address, args, true);
if (prep_outcome.is_error()) {
m_trap = prep_outcome.release_error();
return Outcome::Return;
}
final_outcome = Outcome::Return; // At this point we can only ever return (unless we succeed in tail-calling).
if (prep_outcome.value().has_value()) {
result = prep_outcome.value()->function()(configuration, args);
} else {
configuration.ip() = 0;
return static_cast<Outcome>(0); // Continue from IP 0 in the new frame.
}
} else {
if (instance->has<WasmFunction>()) { if (instance->has<WasmFunction>()) {
CallFrameHandle handle { *this, configuration }; CallFrameHandle handle { *this, configuration };
result = configuration.call(*this, address, move(args)); result = configuration.call(*this, address, move(args));
} else { } else {
result = configuration.call(*this, address, move(args)); result = configuration.call(*this, address, move(args));
} }
}
if (result.is_trap()) { if (result.is_trap()) {
m_trap = move(result.trap()); m_trap = move(result.trap());
return true; return Outcome::Return;
} }
if (!result.values().is_empty()) { if (!result.values().is_empty()) {
@ -3908,7 +3977,7 @@ bool BytecodeInterpreter::call_address(Configuration& configuration, FunctionAdd
configuration.value_stack().unchecked_append(entry); configuration.value_stack().unchecked_append(entry);
} }
return false; return final_outcome;
} }
template<typename PopTypeLHS, typename PushType, typename Operator, typename PopTypeRHS, typename... Args> template<typename PopTypeLHS, typename PushType, typename Operator, typename PopTypeRHS, typename... Args>

View file

@ -21,6 +21,12 @@ union SourcesAndDestination {
u32 sources_and_destination; u32 sources_and_destination;
}; };
enum class Outcome : u64 {
// 0..Constants::max_allowed_executed_instructions_per_call -> next IP.
Continue = Constants::max_allowed_executed_instructions_per_call + 1,
Return,
};
struct WASM_API BytecodeInterpreter final : public Interpreter { struct WASM_API BytecodeInterpreter final : public Interpreter {
explicit BytecodeInterpreter(StackInfo const& stack_info) explicit BytecodeInterpreter(StackInfo const& stack_info)
: m_stack_info(stack_info) : m_stack_info(stack_info)
@ -59,6 +65,8 @@ struct WASM_API BytecodeInterpreter final : public Interpreter {
enum class CallAddressSource { enum class CallAddressSource {
DirectCall, DirectCall,
IndirectCall, IndirectCall,
DirectTailCall,
IndirectTailCall,
}; };
template<bool HasCompiledList, bool HasDynamicInsnLimit, bool HaveDirectThreadingInfo> template<bool HasCompiledList, bool HasDynamicInsnLimit, bool HaveDirectThreadingInfo>
@ -88,7 +96,7 @@ struct WASM_API BytecodeInterpreter final : public Interpreter {
template<typename M, template<typename> typename SetSign, typename VectorType = Native128ByteVectorOf<M, SetSign>> template<typename M, template<typename> typename SetSign, typename VectorType = Native128ByteVectorOf<M, SetSign>>
VectorType pop_vector(Configuration&, size_t source, SourcesAndDestination const&); VectorType pop_vector(Configuration&, size_t source, SourcesAndDestination const&);
bool store_to_memory(Configuration&, Instruction::MemoryArgument const&, ReadonlyBytes data, u32 base); bool store_to_memory(Configuration&, Instruction::MemoryArgument const&, ReadonlyBytes data, u32 base);
bool call_address(Configuration&, FunctionAddress, CallAddressSource = CallAddressSource::DirectCall); Outcome call_address(Configuration&, FunctionAddress, CallAddressSource = CallAddressSource::DirectCall);
template<typename T> template<typename T>
bool store_to_memory(MemoryInstance&, u64 address, T value); bool store_to_memory(MemoryInstance&, u64 address, T value);

View file

@ -11,7 +11,7 @@
namespace Wasm { namespace Wasm {
void Configuration::unwind(Badge<CallFrameHandle>, CallFrameHandle const&) void Configuration::unwind_impl()
{ {
m_frame_stack.take_last(); m_frame_stack.take_last();
m_depth--; m_depth--;
@ -19,11 +19,22 @@ void Configuration::unwind(Badge<CallFrameHandle>, CallFrameHandle const&)
} }
Result Configuration::call(Interpreter& interpreter, FunctionAddress address, Vector<Value> arguments) Result Configuration::call(Interpreter& interpreter, FunctionAddress address, Vector<Value> arguments)
{
if (auto fn = TRY(prepare_call(address, arguments)); fn.has_value())
return fn->function()(*this, arguments);
m_ip = 0;
return execute(interpreter);
}
ErrorOr<Optional<HostFunction&>, Trap> Configuration::prepare_call(FunctionAddress address, Vector<Value>& arguments, bool is_tailcall)
{ {
auto* function = m_store.get(address); auto* function = m_store.get(address);
if (!function) if (!function)
return Trap::from_string("Attempt to call nonexistent function by address"); return Trap::from_string("Attempt to call nonexistent function by address");
if (auto* wasm_function = function->get_pointer<WasmFunction>()) { if (auto* wasm_function = function->get_pointer<WasmFunction>()) {
if (is_tailcall)
unwind_impl(); // Unwind the current frame, the "return" in the tail-called function will unwind the frame we're gonna push now.
Vector<Value> locals = move(arguments); Vector<Value> locals = move(arguments);
locals.ensure_capacity(locals.size() + wasm_function->code().func().locals().size()); locals.ensure_capacity(locals.size() + wasm_function->code().func().locals().size());
for (auto& local : wasm_function->code().func().locals()) { for (auto& local : wasm_function->code().func().locals()) {
@ -36,14 +47,12 @@ Result Configuration::call(Interpreter& interpreter, FunctionAddress address, Ve
move(locals), move(locals),
wasm_function->code().func().body(), wasm_function->code().func().body(),
wasm_function->type().results().size(), wasm_function->type().results().size(),
}); },
m_ip = 0; is_tailcall);
return execute(interpreter); return OptionalNone {};
} }
// It better be a host function, else something is really wrong. return function->get<HostFunction>();
auto& host_function = function->get<HostFunction>();
return host_function.function()(*this, arguments);
} }
Result Configuration::execute(Interpreter& interpreter) Result Configuration::execute(Interpreter& interpreter)

View file

@ -19,7 +19,7 @@ public:
{ {
} }
void set_frame(Frame frame) void set_frame(Frame frame, bool is_tailcall = false)
{ {
auto continuation = frame.expression().instructions().size() - 1; auto continuation = frame.expression().instructions().size() - 1;
if (auto size = frame.expression().compiled_instructions.dispatches.size(); size > 0) if (auto size = frame.expression().compiled_instructions.dispatches.size(); size > 0)
@ -28,8 +28,10 @@ public:
frame.label_index() = m_label_stack.size(); frame.label_index() = m_label_stack.size();
if (auto hint = frame.expression().stack_usage_hint(); hint.has_value()) if (auto hint = frame.expression().stack_usage_hint(); hint.has_value())
m_value_stack.ensure_capacity(*hint + m_value_stack.size()); m_value_stack.ensure_capacity(*hint + m_value_stack.size());
if (!is_tailcall) {
if (auto hint = frame.expression().frame_usage_hint(); hint.has_value()) if (auto hint = frame.expression().frame_usage_hint(); hint.has_value())
m_label_stack.ensure_capacity(*hint + m_label_stack.size()); m_label_stack.ensure_capacity(*hint + m_label_stack.size());
}
m_frame_stack.append(move(frame)); m_frame_stack.append(move(frame));
m_label_stack.append(label); m_label_stack.append(label);
m_locals_base = m_frame_stack.unchecked_last().locals().data(); m_locals_base = m_frame_stack.unchecked_last().locals().data();
@ -66,7 +68,8 @@ public:
Configuration& configuration; Configuration& configuration;
}; };
void unwind(Badge<CallFrameHandle>, CallFrameHandle const&); void unwind(Badge<CallFrameHandle>, CallFrameHandle const&) { unwind_impl(); }
ErrorOr<Optional<HostFunction&>, Trap> prepare_call(FunctionAddress, Vector<Value>& arguments, bool is_tailcall = false);
Result call(Interpreter&, FunctionAddress, Vector<Value> arguments); Result call(Interpreter&, FunctionAddress, Vector<Value> arguments);
Result execute(Interpreter&); Result execute(Interpreter&);
@ -113,6 +116,8 @@ public:
}; };
private: private:
void unwind_impl();
Store& m_store; Store& m_store;
Vector<Value, 64, FastLastAccess::Yes> m_value_stack; Vector<Value, 64, FastLastAccess::Yes> m_value_stack;
Vector<Label, 64> m_label_stack; Vector<Label, 64> m_label_stack;

View file

@ -2142,6 +2142,52 @@ VALIDATE_INSTRUCTION(call_indirect)
return {}; return {};
} }
VALIDATE_INSTRUCTION(return_call)
{
auto index = instruction.arguments().get<FunctionIndex>();
TRY(validate(index));
auto& function_type = m_context.functions[index.value()];
for (size_t i = 0; i < function_type.parameters().size(); ++i)
TRY(stack.take(function_type.parameters()[function_type.parameters().size() - i - 1]));
auto const& return_types = m_frames.first().type.results();
if (return_types != function_type.results())
return Errors::invalid("return_call target"sv, function_type.results(), return_types);
m_frames.last().unreachable = true;
stack.resize(m_frames.last().initial_size);
return {};
}
VALIDATE_INSTRUCTION(return_call_indirect)
{
auto& args = instruction.arguments().get<Instruction::IndirectCallArgs>();
TRY(validate(args.table));
TRY(validate(args.type));
auto& table = m_context.tables[args.table.value()];
if (table.element_type().kind() != ValueType::FunctionReference)
return Errors::invalid("table element type for call.indirect"sv, "a function reference"sv, table.element_type());
auto& type = m_context.types[args.type.value()];
TRY(stack.take<ValueType::I32>());
for (size_t i = 0; i < type.parameters().size(); ++i)
TRY(stack.take(type.parameters()[type.parameters().size() - i - 1]));
auto& return_types = m_frames.first().type.results();
if (return_types != type.results())
return Errors::invalid("return_call_indirect target"sv, type.results(), return_types);
m_frames.last().unreachable = true;
stack.resize(m_frames.last().initial_size);
return {};
}
VALIDATE_INSTRUCTION(v128_load) VALIDATE_INSTRUCTION(v128_load)
{ {
auto& arg = instruction.arguments().get<Instruction::MemoryArgument>(); auto& arg = instruction.arguments().get<Instruction::MemoryArgument>();

View file

@ -29,6 +29,8 @@ namespace Instructions {
M(return_, 0x0f, -1, -1) \ M(return_, 0x0f, -1, -1) \
M(call, 0x10, -1, -1) \ M(call, 0x10, -1, -1) \
M(call_indirect, 0x11, -1, -1) \ M(call_indirect, 0x11, -1, -1) \
M(return_call, 0x12, -1, -1) \
M(return_call_indirect, 0x13, -1, -1) \
M(drop, 0x1a, 1, 0) \ M(drop, 0x1a, 1, 0) \
M(select, 0x1b, 3, 1) \ M(select, 0x1b, 3, 1) \
M(select_typed, 0x1c, 3, 1) \ M(select_typed, 0x1c, 3, 1) \

View file

@ -269,6 +269,17 @@ ParseResult<Instruction> Instruction::parse(ConstrainedStream& stream)
auto table_index = TRY(GenericIndexParser<TableIndex>::parse(stream)); auto table_index = TRY(GenericIndexParser<TableIndex>::parse(stream));
return Instruction { opcode, IndirectCallArgs { type_index, table_index } }; return Instruction { opcode, IndirectCallArgs { type_index, table_index } };
} }
case Instructions::return_call.value(): {
// return_call function
auto function_index = TRY(GenericIndexParser<FunctionIndex>::parse(stream));
return Instruction { opcode, function_index };
}
case Instructions::return_call_indirect.value(): {
// return_call_indirect type table
auto type_index = TRY(GenericIndexParser<TypeIndex>::parse(stream));
auto table_index = TRY(GenericIndexParser<TableIndex>::parse(stream));
return Instruction { opcode, IndirectCallArgs { type_index, table_index } };
}
case Instructions::i32_load.value(): case Instructions::i32_load.value():
case Instructions::i64_load.value(): case Instructions::i64_load.value():
case Instructions::f32_load.value(): case Instructions::f32_load.value():

View file

@ -721,7 +721,9 @@ HashMap<Wasm::OpCode, ByteString> Wasm::Names::instruction_names {
{ Instructions::br_table, "br.table" }, { Instructions::br_table, "br.table" },
{ Instructions::return_, "return" }, { Instructions::return_, "return" },
{ Instructions::call, "call" }, { Instructions::call, "call" },
{ Instructions::return_call, "return_call" },
{ Instructions::call_indirect, "call.indirect" }, { Instructions::call_indirect, "call.indirect" },
{ Instructions::return_call_indirect, "return_call.indirect" },
{ Instructions::drop, "drop" }, { Instructions::drop, "drop" },
{ Instructions::select, "select" }, { Instructions::select, "select" },
{ Instructions::select_typed, "select.typed" }, { Instructions::select_typed, "select.typed" },

View file

@ -197,7 +197,7 @@ def escape(s: str) -> str:
def make_description(input_path: Path, name: str, out_path: Path) -> WastDescription: def make_description(input_path: Path, name: str, out_path: Path) -> WastDescription:
out_json_path = out_path / f"{name}.json" out_json_path = out_path / f"{name}.json"
result = subprocess.run( result = subprocess.run(
["wast2json", input_path, f"--output={out_json_path}", "--no-check"], ["wast2json", "--enable-all", input_path, f"--output={out_json_path}", "--no-check"],
) )
result.check_returncode() result.check_returncode()
with open(out_json_path, "r") as f: with open(out_json_path, "r") as f: