GH-116422: Tier2 hot/cold splitting (GH-116813)

Splits the "cold" path (deopts and exits) from the "hot" path, reducing the size of most jitted instructions at the cost of slower exits.
This commit is contained in:
Mark Shannon 2024-03-26 09:35:11 +00:00 committed by GitHub
parent 61599a48f5
commit bf82f77957
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1662 additions and 1003 deletions

View file

@ -8,7 +8,8 @@
@dataclass
class Properties:
escapes: bool
infallible: bool
error_with_pop: bool
error_without_pop: bool
deopts: bool
oparg: bool
jumps: bool
@ -37,7 +38,8 @@ def dump(self, indent: str) -> None:
def from_list(properties: list["Properties"]) -> "Properties":
return Properties(
escapes=any(p.escapes for p in properties),
infallible=all(p.infallible for p in properties),
error_with_pop=any(p.error_with_pop for p in properties),
error_without_pop=any(p.error_without_pop for p in properties),
deopts=any(p.deopts for p in properties),
oparg=any(p.oparg for p in properties),
jumps=any(p.jumps for p in properties),
@ -55,10 +57,16 @@ def from_list(properties: list["Properties"]) -> "Properties":
passthrough=all(p.passthrough for p in properties),
)
# Derived, read-only flag: true only when neither error_with_pop nor
# error_without_pop is set on this object.
@property
def infallible(self) -> bool:
return not self.error_with_pop and not self.error_without_pop
SKIP_PROPERTIES = Properties(
escapes=False,
infallible=True,
error_with_pop=False,
error_without_pop=False,
deopts=False,
oparg=False,
jumps=False,
@ -157,20 +165,32 @@ def size(self) -> int:
self._size = sum(c.size for c in self.caches)
return self._size
def is_viable(self) -> bool:
def why_not_viable(self) -> str | None:
if self.name == "_SAVE_RETURN_OFFSET":
return True # Adjusts next_instr, but only in tier 1 code
if self.properties.needs_this:
return False
return None # Adjusts next_instr, but only in tier 1 code
if "INSTRUMENTED" in self.name:
return False
return "is instrumented"
if "replaced" in self.annotations:
return False
return "is replaced"
if self.name in ("INTERPRETER_EXIT", "JUMP_BACKWARD"):
return False
return "has tier 1 control flow"
if self.properties.needs_this:
return "uses the 'this_instr' variable"
if len([c for c in self.caches if c.name != "unused"]) > 1:
return False
return True
return "has unused cache entries"
if self.properties.error_with_pop and self.properties.error_without_pop:
return "has both popping and not-popping errors"
if self.properties.eval_breaker:
if self.properties.error_with_pop or self.properties.error_without_pop:
return "has error handling and eval-breaker check"
if self.properties.side_exit:
return "exits and eval-breaker check"
if self.properties.deopts:
return "deopts and eval-breaker check"
return None
# Boolean convenience wrapper: viable exactly when why_not_viable()
# reports no reason (i.e. returns None).
def is_viable(self) -> bool:
return self.why_not_viable() is None
def is_super(self) -> bool:
for tkn in self.body:
@ -320,10 +340,17 @@ def tier_variable(node: parser.InstDef) -> int | None:
return int(token.text[-1])
return None
def is_infallible(op: parser.InstDef) -> bool:
return not (
def has_error_with_pop(op: parser.InstDef) -> bool:
return (
variable_used(op, "ERROR_IF")
or variable_used(op, "error")
or variable_used(op, "pop_1_error")
or variable_used(op, "exception_unwind")
or variable_used(op, "resume_with_error")
)
def has_error_without_pop(op: parser.InstDef) -> bool:
return (
variable_used(op, "ERROR_NO_POP")
or variable_used(op, "pop_1_error")
or variable_used(op, "exception_unwind")
or variable_used(op, "resume_with_error")
@ -507,12 +534,15 @@ def compute_properties(op: parser.InstDef) -> Properties:
tkn.column,
op.name,
)
infallible = is_infallible(op)
error_with_pop = has_error_with_pop(op)
error_without_pop = has_error_without_pop(op)
infallible = not error_with_pop and not error_without_pop
passthrough = stack_effect_only_peeks(op) and infallible
return Properties(
escapes=makes_escaping_api_call(op),
infallible=infallible,
deopts=deopts_if or exits_if,
error_with_pop=error_with_pop,
error_without_pop=error_without_pop,
deopts=deopts_if,
side_exit=exits_if,
oparg=variable_used(op, "oparg"),
jumps=variable_used(op, "JUMPBY"),

View file

@ -99,6 +99,20 @@ def replace_error(
out.emit(close)
# Replacement handler with the same parameter list as the other replace_*
# handlers. It discards the next three tokens from tkn_iter (expected to be
# "(", ")" and ";") and emits the fixed statement "goto error;" through
# out.emit_at, passing the triggering token `tkn`.
# `uop`, `stack` and `inst` are accepted but not used in this body.
def replace_error_no_pop(
out: CWriter,
tkn: Token,
tkn_iter: Iterator[Token],
uop: Uop,
stack: Stack,
inst: Instruction | None,
) -> None:
next(tkn_iter) # LPAREN
next(tkn_iter) # RPAREN
next(tkn_iter) # Semicolon
out.emit_at("goto error;", tkn)
def replace_decrefs(
out: CWriter,
tkn: Token,
@ -160,6 +174,7 @@ def replace_check_eval_breaker(
"EXIT_IF": replace_deopt,
"DEOPT_IF": replace_deopt,
"ERROR_IF": replace_error,
"ERROR_NO_POP": replace_error_no_pop,
"DECREF_INPUTS": replace_decrefs,
"CHECK_EVAL_BREAKER": replace_check_eval_breaker,
"SYNC_SP": replace_sync_sp,
@ -213,6 +228,8 @@ def cflags(p: Properties) -> str:
flags.append("HAS_EXIT_FLAG")
if not p.infallible:
flags.append("HAS_ERROR_FLAG")
if p.error_without_pop:
flags.append("HAS_ERROR_NO_POP_FLAG")
if p.escapes:
flags.append("HAS_ESCAPES_FLAG")
if p.pure:

View file

@ -54,6 +54,7 @@
"PURE",
"PASSTHROUGH",
"OPARG_AND_1",
"ERROR_NO_POP",
]

View file

@ -72,22 +72,22 @@ def tier2_replace_error(
label = next(tkn_iter).text
next(tkn_iter) # RPAREN
next(tkn_iter) # Semi colon
out.emit(") ")
c_offset = stack.peek_offset.to_c()
try:
offset = -int(c_offset)
close = ";\n"
except ValueError:
offset = None
out.emit(f"{{ stack_pointer += {c_offset}; ")
close = "; }\n"
out.emit("goto ")
if offset:
out.emit(f"pop_{offset}_")
out.emit(label + "_tier_two")
out.emit(close)
out.emit(") JUMP_TO_ERROR();\n")
# Tier 2 replacement handler: discards the next three tokens from tkn_iter
# (expected to be "(", ")" and ";") and emits "JUMP_TO_ERROR();" through
# out.emit_at, passing the triggering token `tkn`.
# `uop`, `stack` and `inst` are accepted but not used in this body.
def tier2_replace_error_no_pop(
out: CWriter,
tkn: Token,
tkn_iter: Iterator[Token],
uop: Uop,
stack: Stack,
inst: Instruction | None,
) -> None:
next(tkn_iter) # LPAREN
next(tkn_iter) # RPAREN
next(tkn_iter) # Semicolon
out.emit_at("JUMP_TO_ERROR();", tkn)
def tier2_replace_deopt(
out: CWriter,
tkn: Token,
@ -100,7 +100,7 @@ def tier2_replace_deopt(
out.emit(next(tkn_iter))
emit_to(out, tkn_iter, "RPAREN")
next(tkn_iter) # Semi colon
out.emit(") goto deoptimize;\n")
out.emit(") JUMP_TO_JUMP_TARGET();\n")
def tier2_replace_exit_if(
@ -115,7 +115,7 @@ def tier2_replace_exit_if(
out.emit(next(tkn_iter))
emit_to(out, tkn_iter, "RPAREN")
next(tkn_iter) # Semi colon
out.emit(") goto side_exit;\n")
out.emit(") JUMP_TO_JUMP_TARGET();\n")
def tier2_replace_oparg(
@ -141,6 +141,7 @@ def tier2_replace_oparg(
TIER2_REPLACEMENT_FUNCTIONS = REPLACEMENT_FUNCTIONS.copy()
TIER2_REPLACEMENT_FUNCTIONS["ERROR_IF"] = tier2_replace_error
TIER2_REPLACEMENT_FUNCTIONS["ERROR_NO_POP"] = tier2_replace_error_no_pop
TIER2_REPLACEMENT_FUNCTIONS["DEOPT_IF"] = tier2_replace_deopt
TIER2_REPLACEMENT_FUNCTIONS["oparg"] = tier2_replace_oparg
TIER2_REPLACEMENT_FUNCTIONS["EXIT_IF"] = tier2_replace_exit_if
@ -201,8 +202,9 @@ def generate_tier2(
continue
if uop.is_super():
continue
if not uop.is_viable():
out.emit(f"/* {uop.name} is not a viable micro-op for tier 2 */\n\n")
why_not_viable = uop.why_not_viable()
if why_not_viable is not None:
out.emit(f"/* {uop.name} is not a viable micro-op for tier 2 because it {why_not_viable} */\n\n")
continue
out.emit(f"case {uop.name}: {{\n")
declare_variables(uop, out)

View file

@ -15,10 +15,10 @@
write_header,
cflags,
)
from stack import Stack
from cwriter import CWriter
from typing import TextIO
DEFAULT_OUTPUT = ROOT / "Include/internal/pycore_uop_metadata.h"
@ -26,6 +26,7 @@ def generate_names_and_flags(analysis: Analysis, out: CWriter) -> None:
out.emit("extern const uint16_t _PyUop_Flags[MAX_UOP_ID+1];\n")
out.emit("extern const uint8_t _PyUop_Replication[MAX_UOP_ID+1];\n")
out.emit("extern const char * const _PyOpcode_uop_name[MAX_UOP_ID+1];\n\n")
out.emit("extern int _PyUop_num_popped(int opcode, int oparg);\n\n")
out.emit("#ifdef NEED_OPCODE_METADATA\n")
out.emit("const uint16_t _PyUop_Flags[MAX_UOP_ID+1] = {\n")
for uop in analysis.uops.values():
@ -44,6 +45,20 @@ def generate_names_and_flags(analysis: Analysis, out: CWriter) -> None:
if uop.is_viable() and uop.properties.tier != 1:
out.emit(f'[{uop.name}] = "{uop.name}",\n')
out.emit("};\n")
out.emit("int _PyUop_num_popped(int opcode, int oparg)\n{\n")
out.emit("switch(opcode) {\n")
for uop in analysis.uops.values():
if uop.is_viable() and uop.properties.tier != 1:
stack = Stack()
for var in reversed(uop.stack.inputs):
stack.pop(var)
popped = (-stack.base_offset).to_c()
out.emit(f"case {uop.name}:\n")
out.emit(f" return {popped};\n")
out.emit("default:\n")
out.emit(" return -1;\n")
out.emit("}\n")
out.emit("}\n\n")
out.emit("#endif // NEED_OPCODE_METADATA\n\n")

View file

@ -31,6 +31,12 @@ class HoleValue(enum.Enum):
OPERAND = enum.auto()
# The current uop's target (exposed as _JIT_TARGET):
TARGET = enum.auto()
# The base address of the machine code for the jump target (exposed as _JIT_JUMP_TARGET):
JUMP_TARGET = enum.auto()
# The base address of the machine code for the error jump target (exposed as _JIT_ERROR_TARGET):
ERROR_TARGET = enum.auto()
# The index of the exit to be jumped through (exposed as _JIT_EXIT_INDEX):
EXIT_INDEX = enum.auto()
# The base address of the machine code for the first uop (exposed as _JIT_TOP):
TOP = enum.auto()
# A hardcoded value of zero (used for symbol lookups):

View file

@ -64,9 +64,17 @@ do { \
TYPE NAME = (TYPE)(uint64_t)&ALIAS;
#define PATCH_JUMP(ALIAS) \
do { \
PyAPI_DATA(void) ALIAS; \
__attribute__((musttail)) \
return ((jit_func)&ALIAS)(frame, stack_pointer, tstate);
return ((jit_func)&ALIAS)(frame, stack_pointer, tstate); \
} while (0)
/* Redefine JUMP_TO_JUMP_TARGET() and JUMP_TO_ERROR() in terms of PATCH_JUMP
 * on the _JIT_JUMP_TARGET and _JIT_ERROR_TARGET symbols. Any earlier
 * definition of either macro is discarded first. */
#undef JUMP_TO_JUMP_TARGET
#define JUMP_TO_JUMP_TARGET() PATCH_JUMP(_JIT_JUMP_TARGET)
#undef JUMP_TO_ERROR
#define JUMP_TO_ERROR() PATCH_JUMP(_JIT_ERROR_TARGET)
_Py_CODEUNIT *
_JIT_ENTRY(_PyInterpreterFrame *frame, PyObject **stack_pointer, PyThreadState *tstate)
@ -79,6 +87,7 @@ _JIT_ENTRY(_PyInterpreterFrame *frame, PyObject **stack_pointer, PyThreadState *
PATCH_VALUE(uint16_t, _oparg, _JIT_OPARG)
PATCH_VALUE(uint64_t, _operand, _JIT_OPERAND)
PATCH_VALUE(uint32_t, _target, _JIT_TARGET)
PATCH_VALUE(uint16_t, _exit_index, _JIT_EXIT_INDEX)
// The actual instruction definitions (only one will be used):
if (opcode == _JUMP_TO_TOP) {
CHECK_EVAL_BREAKER();
@ -91,28 +100,16 @@ _JIT_ENTRY(_PyInterpreterFrame *frame, PyObject **stack_pointer, PyThreadState *
}
PATCH_JUMP(_JIT_CONTINUE);
// Labels that the instruction implementations expect to exist:
unbound_local_error_tier_two:
_PyEval_FormatExcCheckArg(
tstate, PyExc_UnboundLocalError, UNBOUNDLOCAL_ERROR_MSG,
PyTuple_GetItem(_PyFrame_GetCode(frame)->co_localsplusnames, oparg));
goto error_tier_two;
pop_4_error_tier_two:
STACK_SHRINK(1);
pop_3_error_tier_two:
STACK_SHRINK(1);
pop_2_error_tier_two:
STACK_SHRINK(1);
pop_1_error_tier_two:
STACK_SHRINK(1);
error_tier_two:
tstate->previous_executor = (PyObject *)current_executor;
GOTO_TIER_ONE(NULL);
deoptimize:
exit_to_tier1:
tstate->previous_executor = (PyObject *)current_executor;
GOTO_TIER_ONE(_PyCode_CODE(_PyFrame_GetCode(frame)) + _target);
side_exit:
exit_to_trace:
{
_PyExitData *exit = &current_executor->exits[_target];
_PyExitData *exit = &current_executor->exits[_exit_index];
Py_INCREF(exit->executor);
tstate->previous_executor = (PyObject *)current_executor;
GOTO_TIER_TWO(exit->executor);