diff --git a/core/templates/command_queue_mt.h b/core/templates/command_queue_mt.h index 94a03878d40..7eaf2f94af1 100644 --- a/core/templates/command_queue_mt.h +++ b/core/templates/command_queue_mt.h @@ -39,6 +39,8 @@ #include "core/typedefs.h" class CommandQueueMT { + static const size_t MAX_COMMAND_SIZE = 1024; + struct CommandBase { bool sync = false; virtual void call() = 0; @@ -154,19 +156,28 @@ class CommandQueueMT { } void _flush() { + MutexLock lock(mutex); + if (unlikely(flush_read_ptr)) { // Re-entrant call. return; } - MutexLock lock(mutex); + char cmd_backup[MAX_COMMAND_SIZE]; while (flush_read_ptr < command_mem.size()) { uint64_t size = *(uint64_t *)&command_mem[flush_read_ptr]; - flush_read_ptr += 8; + flush_read_ptr += sizeof(uint64_t); + CommandBase *cmd = reinterpret_cast(&command_mem[flush_read_ptr]); + + // Protect against race condition between this thread + // during the call to the command and other threads potentially + // invalidating the pointer due to reallocs. + memcpy(cmd_backup, (char *)cmd, size); + uint32_t allowance_id = WorkerThreadPool::thread_enter_unlock_allowance_zone(lock); - cmd->call(); + ((CommandBase *)cmd_backup)->call(); WorkerThreadPool::thread_exit_unlock_allowance_zone(allowance_id); // Handle potential realloc due to the command and unlock allowance. @@ -174,9 +185,9 @@ class CommandQueueMT { if (unlikely(cmd->sync)) { sync_head++; - lock.~MutexLock(); // Give an opportunity to awaiters right away. + lock.temp_unlock(); // Give an opportunity to awaiters right away. sync_cond_var.notify_all(); - new (&lock) MutexLock(mutex); + lock.temp_relock(); // Handle potential realloc happened during unlock. cmd = reinterpret_cast(&command_mem[flush_read_ptr]); } @@ -210,6 +221,7 @@ public: void push(T *p_instance, M p_method, Args &&...p_args) { // Standard command, no sync. using CommandType = Command; + static_assert(sizeof(CommandType) <= MAX_COMMAND_SIZE); _push_internal(p_instance, p_method, std::forward(p_args)...); } @@ -217,6 +229,7 @@ public: void push_and_sync(T *p_instance, M p_method, Args... p_args) { // Standard command, sync. using CommandType = Command; + static_assert(sizeof(CommandType) <= MAX_COMMAND_SIZE); _push_internal(p_instance, p_method, std::forward(p_args)...); } @@ -224,6 +237,7 @@ public: void push_and_ret(T *p_instance, M p_method, R *r_ret, Args... p_args) { // Command with return value, sync. using CommandType = CommandRet; + static_assert(sizeof(CommandType) <= MAX_COMMAND_SIZE); _push_internal(p_instance, p_method, r_ret, std::forward(p_args)...); }