Merge pull request #111356 from blueskythlikesclouds/d3d12-spec-constant-patch-fix

Fix specialization constant patching on D3D12.
This commit is contained in:
Thaddeus Crews 2025-10-13 12:30:15 -05:00
commit 599fd7344a
No known key found for this signature in database
GPG key ID: 8C6E5FEB5FC03CCC

View file

@ -94,28 +94,26 @@ uint32_t RenderingDXIL::patch_specialization_constant(
const uint64_t (&p_stages_bit_offsets)[D3D12_BITCODE_OFFSETS_NUM_STAGES],
HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_stages_bytecodes,
bool p_is_first_patch) {
uint32_t patch_val = 0;
int64_t patch_val = 0;
switch (p_type) {
case RenderingDeviceCommons::PIPELINE_SPECIALIZATION_CONSTANT_TYPE_INT: {
uint32_t int_value = *((const int *)p_value);
ERR_FAIL_COND_V(int_value & (1 << 31), 0);
patch_val = int_value;
patch_val = *((const int32_t *)p_value);
} break;
case RenderingDeviceCommons::PIPELINE_SPECIALIZATION_CONSTANT_TYPE_BOOL: {
bool bool_value = *((const bool *)p_value);
patch_val = (uint32_t)bool_value;
patch_val = (int32_t)bool_value;
} break;
case RenderingDeviceCommons::PIPELINE_SPECIALIZATION_CONSTANT_TYPE_FLOAT: {
uint32_t int_value = *((const int *)p_value);
ERR_FAIL_COND_V(int_value & (1 << 31), 0);
patch_val = (int_value >> 1);
patch_val = *((const int32_t *)p_value);
} break;
}
// For VBR encoding to encode the number of bits we expect (32), we need to set the MSB unconditionally.
// However, signed VBR moves the MSB to the LSB, so setting the MSB to 1 wouldn't help. Therefore,
// the bit we set to 1 is the one at index 30.
patch_val |= (1 << 30);
patch_val <<= 1; // What signed VBR does.
// Encode to signed VBR.
if (patch_val >= 0) {
patch_val <<= 1;
} else {
patch_val = ((-patch_val) << 1) | 1;
}
auto tamper_bits = [](uint8_t *p_start, uint64_t p_bit_offset, uint64_t p_tb_value) -> uint64_t {
uint64_t original = 0;
@ -169,13 +167,13 @@ uint32_t RenderingDXIL::patch_specialization_constant(
Vector<uint8_t> &bytecode = r_stages_bytecodes[(RenderingDeviceCommons::ShaderStage)stage];
#ifdef DEV_ENABLED
uint64_t orig_patch_val = tamper_bits(bytecode.ptrw(), offset, patch_val);
uint64_t orig_patch_val = tamper_bits(bytecode.ptrw(), offset, (uint64_t)patch_val);
// Checking against the value the NIR patch should have set.
DEV_ASSERT(!p_is_first_patch || ((orig_patch_val >> 1) & GODOT_NIR_SC_SENTINEL_MAGIC_MASK) == GODOT_NIR_SC_SENTINEL_MAGIC);
uint64_t readback_patch_val = tamper_bits(bytecode.ptrw(), offset, patch_val);
DEV_ASSERT(readback_patch_val == patch_val);
uint64_t readback_patch_val = tamper_bits(bytecode.ptrw(), offset, (uint64_t)patch_val);
DEV_ASSERT(readback_patch_val == (uint64_t)patch_val);
#else
tamper_bits(bytecode.ptrw(), offset, patch_val);
tamper_bits(bytecode.ptrw(), offset, (uint64_t)patch_val);
#endif
stages_patched_mask |= (1 << stage);