diff --git a/servers/rendering/shader_language.cpp b/servers/rendering/shader_language.cpp index 092ebff79d1..6e4d150ee15 100644 --- a/servers/rendering/shader_language.cpp +++ b/servers/rendering/shader_language.cpp @@ -1555,10 +1555,11 @@ bool ShaderLanguage::_find_identifier(const BlockNode *p_block, bool p_allow_rea return false; } -bool ShaderLanguage::_validate_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op, DataType *r_ret_type, int *r_ret_size) { +bool ShaderLanguage::_validate_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op, DataType *r_ret_type, int *r_ret_size, StringName *r_ret_struct_name) { bool valid = false; DataType ret_type = TYPE_VOID; int ret_size = 0; + String ret_struct_name; switch (p_op->op) { case OP_EQUAL: @@ -2003,7 +2004,23 @@ bool ShaderLanguage::_validate_operator(const BlockNode *p_block, const Function DataType nb = p_op->arguments[1]->get_datatype(); DataType nc = p_op->arguments[2]->get_datatype(); - valid = na == TYPE_BOOL && (nb == nc) && !is_sampler_type(nb); + bool is_same = false; + if (nb == nc) { + is_same = true; + + if (nb == TYPE_STRUCT) { + String tb = p_op->arguments[1]->get_datatype_name(); + String tc = p_op->arguments[2]->get_datatype_name(); + + if (tb != tc) { + break; + } + + ret_struct_name = tb; + } + } + + valid = na == TYPE_BOOL && is_same && !is_sampler_type(nb); ret_type = nb; ret_size = sa; } break; @@ -2018,6 +2035,9 @@ bool ShaderLanguage::_validate_operator(const BlockNode *p_block, const Function if (r_ret_size) { *r_ret_size = ret_size; } + if (r_ret_struct_name) { + *r_ret_struct_name = ret_struct_name; + } if (valid && (!p_block || p_block->use_op_eval)) { // Need to be placed here and not in the `_reduce_expression` because otherwise expressions like `1 + 2 / 2` will not work correctly. @@ -7699,7 +7719,7 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons expression.write[next_op - 1].is_op = false; expression.write[next_op - 1].node = op; - if (!_validate_operator(p_block, p_function_info, op, &op->return_cache, &op->return_array_size)) { + if (!_validate_operator(p_block, p_function_info, op, &op->return_cache, &op->return_array_size, &op->struct_name)) { if (error_set) { return nullptr; } @@ -7709,7 +7729,12 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons if (i > 0) { at += ", "; } - at += get_datatype_name(op->arguments[i]->get_datatype()); + DataType dt = op->arguments[i]->get_datatype(); + if (dt == TYPE_STRUCT) { + at += op->arguments[i]->get_datatype_name(); + } else { + at += get_datatype_name(dt); + } if (!op->arguments[i]->is_indexed() && op->arguments[i]->get_array_size() > 0) { at += "["; at += itos(op->arguments[i]->get_array_size()); diff --git a/servers/rendering/shader_language.h b/servers/rendering/shader_language.h index fed8550ab96..ce9244d9528 100644 --- a/servers/rendering/shader_language.h +++ b/servers/rendering/shader_language.h @@ -1132,7 +1132,7 @@ private: #endif // DEBUG_ENABLED bool _is_operator_assign(Operator p_op) const; bool _validate_assign(Node *p_node, const FunctionInfo &p_function_info, String *r_message = nullptr); - bool _validate_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op, DataType *r_ret_type = nullptr, int *r_ret_size = nullptr); + bool _validate_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op, DataType *r_ret_type = nullptr, int *r_ret_size = nullptr, StringName *r_ret_struct_name = nullptr); Vector _get_node_values(const BlockNode *p_block, const FunctionInfo &p_function_info, Node *p_node); bool _eval_operator(const BlockNode *p_block, const FunctionInfo &p_function_info, OperatorNode *p_op);