gh-130415: Narrow types to constants in branches involving specialized comparisons with a constant (GH-144150)

This commit is contained in:
reiden 2026-01-24 18:02:08 +08:00 committed by GitHub
parent 29840247ff
commit 6d972e0104
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 360 additions and 8 deletions

View file

@ -250,6 +250,11 @@ add_op(JitOptContext *ctx, _PyUOpInstruction *this_instr,
#define sym_new_predicate _Py_uop_sym_new_predicate
#define sym_apply_predicate_narrowing _Py_uop_sym_apply_predicate_narrowing
/* Comparison oparg masks */
#define COMPARE_LT_MASK 2
#define COMPARE_GT_MASK 4
#define COMPARE_EQ_MASK 8
#define JUMP_TO_LABEL(label) goto label;
static int

View file

@ -521,21 +521,51 @@ dummy_func(void) {
}
op(_COMPARE_OP_INT, (left, right -- res, l, r)) {
res = sym_new_type(ctx, &PyBool_Type);
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
if (cmp_mask == COMPARE_EQ_MASK) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
}
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
}
else {
res = sym_new_type(ctx, &PyBool_Type);
}
l = left;
r = right;
REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);
}
op(_COMPARE_OP_FLOAT, (left, right -- res, l, r)) {
res = sym_new_type(ctx, &PyBool_Type);
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
if (cmp_mask == COMPARE_EQ_MASK) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
}
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
}
else {
res = sym_new_type(ctx, &PyBool_Type);
}
l = left;
r = right;
REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);
}
op(_COMPARE_OP_STR, (left, right -- res, l, r)) {
res = sym_new_type(ctx, &PyBool_Type);
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
if (cmp_mask == COMPARE_EQ_MASK) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
}
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
}
else {
res = sym_new_type(ctx, &PyBool_Type);
}
l = left;
r = right;
REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);

View file

@ -2118,7 +2118,16 @@
JitOptRef r;
right = stack_pointer[-1];
left = stack_pointer[-2];
res = sym_new_type(ctx, &PyBool_Type);
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
if (cmp_mask == COMPARE_EQ_MASK) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
}
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
}
else {
res = sym_new_type(ctx, &PyBool_Type);
}
l = left;
r = right;
if (
@ -2178,7 +2187,16 @@
JitOptRef r;
right = stack_pointer[-1];
left = stack_pointer[-2];
res = sym_new_type(ctx, &PyBool_Type);
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
if (cmp_mask == COMPARE_EQ_MASK) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
}
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
}
else {
res = sym_new_type(ctx, &PyBool_Type);
}
l = left;
r = right;
if (
@ -2242,7 +2260,16 @@
JitOptRef r;
right = stack_pointer[-1];
left = stack_pointer[-2];
res = sym_new_type(ctx, &PyBool_Type);
int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK);
if (cmp_mask == COMPARE_EQ_MASK) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
}
else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
}
else {
res = sym_new_type(ctx, &PyBool_Type);
}
l = left;
r = right;
if (

View file

@ -875,9 +875,11 @@ _Py_uop_sym_apply_predicate_narrowing(JitOptContext *ctx, JitOptRef ref, bool br
bool narrow = false;
switch(pred.kind) {
case JIT_PRED_EQ:
case JIT_PRED_IS:
narrow = branch_is_true;
break;
case JIT_PRED_NE:
case JIT_PRED_IS_NOT:
narrow = !branch_is_true;
break;
@ -1300,11 +1302,11 @@ _Py_uop_symbols_test(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(ignored))
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (None)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == Py_None, "predicate narrowing did not narrow subject to None");
// Test narrowing subject to numerical constant
// Test narrowing subject to numerical constant from is comparison
subject = _Py_uop_sym_new_unknown(ctx);
PyObject *one_obj = PyLong_FromLong(1);
JitOptRef const_one = _Py_uop_sym_new_const(ctx, one_obj);
if (PyJitRef_IsNull(subject) || PyJitRef_IsNull(const_one)) {
if (PyJitRef_IsNull(subject) || one_obj == NULL || PyJitRef_IsNull(const_one)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_IS);
@ -1315,6 +1317,160 @@ _Py_uop_symbols_test(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(ignored))
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (1)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate narrowing did not narrow subject to 1");
// Test narrowing subject to constant from EQ predicate for int
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_EQ);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (1)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate narrowing did not narrow subject to 1");
// Resolving EQ predicate to False should not narrow subject for int
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_EQ);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)");
// Test narrowing subject to constant from NE predicate for int
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_NE);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (1)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate narrowing did not narrow subject to 1");
// Resolving NE predicate to true should not narrow subject for int
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_NE);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)");
// Test narrowing subject to constant from EQ predicate for float
subject = _Py_uop_sym_new_unknown(ctx);
PyObject *float_tenth_obj = PyFloat_FromDouble(0.1);
JitOptRef const_float_tenth = _Py_uop_sym_new_const(ctx, float_tenth_obj);
if (PyJitRef_IsNull(subject) || float_tenth_obj == NULL || PyJitRef_IsNull(const_float_tenth)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_EQ);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (float)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == float_tenth_obj, "predicate narrowing did not narrow subject to 0.1");
// Resolving EQ predicate to False should not narrow subject for float
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_EQ);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)");
// Test narrowing subject to constant from NE predicate for float
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_NE);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (float)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == float_tenth_obj, "predicate narrowing did not narrow subject to 0.1");
// Resolving NE predicate to true should not narrow subject for float
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_NE);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)");
// Test narrowing subject to constant from EQ predicate for str
subject = _Py_uop_sym_new_unknown(ctx);
PyObject *str_hello_obj = PyUnicode_FromString("hello");
JitOptRef const_str_hello = _Py_uop_sym_new_const(ctx, str_hello_obj);
if (PyJitRef_IsNull(subject) || str_hello_obj == NULL || PyJitRef_IsNull(const_str_hello)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_EQ);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (str)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == str_hello_obj, "predicate narrowing did not narrow subject to hello");
// Resolving EQ predicate to False should not narrow subject for str
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_EQ);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)");
// Test narrowing subject to constant from NE predicate for str
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_NE);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (str)");
TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == str_hello_obj, "predicate narrowing did not narrow subject to hello");
// Resolving NE predicate to true should not narrow subject for str
subject = _Py_uop_sym_new_unknown(ctx);
if (PyJitRef_IsNull(subject)) {
goto fail;
}
ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_NE);
if (PyJitRef_IsNull(ref)) {
goto fail;
}
_Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)");
val_big = PyNumber_Lshift(_PyLong_GetOne(), PyLong_FromLong(66));
if (val_big == NULL) {
goto fail;