https://github.com/python/cpython/commit/49baa656cb994122869bc807a88ea2f3f0d7751b
commit: 49baa656cb994122869bc807a88ea2f3f0d7751b
branch: main
author: Brandt Bucher <[email protected]>
committer: brandtbucher <[email protected]>
date: 2024-05-01T08:05:53-07:00
summary:

GH-115802: Use the GHC calling convention in JIT code (GH-118287)

files:
A Tools/jit/trampoline.c
M Include/cpython/optimizer.h
M Lib/test/test_perf_profiler.py
M Python/jit.c
M Python/optimizer.c
M Tools/jit/_targets.py
M Tools/jit/_writer.py
M Tools/jit/template.c

diff --git a/Include/cpython/optimizer.h b/Include/cpython/optimizer.h
index a169280b26a6ad..60b35747deb4f3 100644
--- a/Include/cpython/optimizer.h
+++ b/Include/cpython/optimizer.h
@@ -102,6 +102,7 @@ typedef struct _PyExecutorObject {
     uint32_t code_size;
     size_t jit_size;
     void *jit_code;
+    void *jit_side_entry;
     _PyExitData exits[1];
 } _PyExecutorObject;
 
diff --git a/Lib/test/test_perf_profiler.py b/Lib/test/test_perf_profiler.py
index 040be63da11447..e7c03b99086013 100644
--- a/Lib/test/test_perf_profiler.py
+++ b/Lib/test/test_perf_profiler.py
@@ -216,7 +216,7 @@ def is_unwinding_reliable():
     cflags = sysconfig.get_config_var("PY_CORE_CFLAGS")
     if not cflags:
         return False
-    return "no-omit-frame-pointer" in cflags
+    return "no-omit-frame-pointer" in cflags and "_Py_JIT" not in cflags
 
 
 def perf_command_works():
diff --git a/Python/jit.c b/Python/jit.c
index 75ec4fb9756eb7..df14e48c564447 100644
--- a/Python/jit.c
+++ b/Python/jit.c
@@ -385,8 +385,8 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction *trace, size
 {
     // Loop once to find the total compiled size:
     size_t instruction_starts[UOP_MAX_TRACE_LENGTH];
-    size_t code_size = 0;
-    size_t data_size = 0;
+    size_t code_size = trampoline.code.body_size;
+    size_t data_size = trampoline.data.body_size;
     for (size_t i = 0; i < length; i++) {
         _PyUOpInstruction *instruction = (_PyUOpInstruction *)&trace[i];
         const StencilGroup *group = &stencil_groups[instruction->opcode];
@@ -408,11 +408,29 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction *trace, size
     // Loop again to emit the code:
     unsigned char *code = memory;
     unsigned char *data = memory + code_size;
+    {
+        // Compile the trampoline, which handles converting between the native
+        // calling convention and the calling convention used by jitted code
+        // (which may be different for efficiency reasons). On platforms where
+        // we don't change calling conventions, the trampoline is empty and
+        // nothing is emitted here:
+        const StencilGroup *group = &trampoline;
+        // Think of patches as a dictionary mapping HoleValue to uintptr_t:
+        uintptr_t patches[] = GET_PATCHES();
+        patches[HoleValue_CODE] = (uintptr_t)code;
+        patches[HoleValue_CONTINUE] = (uintptr_t)code + group->code.body_size;
+        patches[HoleValue_DATA] = (uintptr_t)data;
+        patches[HoleValue_EXECUTOR] = (uintptr_t)executor;
+        patches[HoleValue_TOP] = (uintptr_t)memory + trampoline.code.body_size;
+        patches[HoleValue_ZERO] = 0;
+        emit(group, patches);
+        code += group->code.body_size;
+        data += group->data.body_size;
+    }
     assert(trace[0].opcode == _START_EXECUTOR || trace[0].opcode == _COLD_EXIT);
     for (size_t i = 0; i < length; i++) {
         _PyUOpInstruction *instruction = (_PyUOpInstruction *)&trace[i];
         const StencilGroup *group = &stencil_groups[instruction->opcode];
-        // Think of patches as a dictionary mapping HoleValue to uintptr_t:
         uintptr_t patches[] = GET_PATCHES();
         patches[HoleValue_CODE] = (uintptr_t)code;
         patches[HoleValue_CONTINUE] = (uintptr_t)code + group->code.body_size;
@@ -454,18 +472,20 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction *trace, size
         code += group->code.body_size;
         data += group->data.body_size;
     }
-    // Protect against accidental buffer overrun into data:
-    const StencilGroup *group = &stencil_groups[_FATAL_ERROR];
-    uintptr_t patches[] = GET_PATCHES();
-    patches[HoleValue_CODE] = (uintptr_t)code;
-    patches[HoleValue_CONTINUE] = (uintptr_t)code;
-    patches[HoleValue_DATA] = (uintptr_t)data;
-    patches[HoleValue_EXECUTOR] = (uintptr_t)executor;
-    patches[HoleValue_TOP] = (uintptr_t)code;
-    patches[HoleValue_ZERO] = 0;
-    emit(group, patches);
-    code += group->code.body_size;
-    data += group->data.body_size;
+    {
+        // Protect against accidental buffer overrun into data:
+        const StencilGroup *group = &stencil_groups[_FATAL_ERROR];
+        uintptr_t patches[] = GET_PATCHES();
+        patches[HoleValue_CODE] = (uintptr_t)code;
+        patches[HoleValue_CONTINUE] = (uintptr_t)code;
+        patches[HoleValue_DATA] = (uintptr_t)data;
+        patches[HoleValue_EXECUTOR] = (uintptr_t)executor;
+        patches[HoleValue_TOP] = (uintptr_t)code;
+        patches[HoleValue_ZERO] = 0;
+        emit(group, patches);
+        code += group->code.body_size;
+        data += group->data.body_size;
+    }
     assert(code == memory + code_size);
     assert(data == memory + code_size + data_size);
     if (mark_executable(memory, total_size)) {
@@ -473,6 +493,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction *trace, size
         return -1;
     }
     executor->jit_code = memory;
+    executor->jit_side_entry = memory + trampoline.code.body_size;
     executor->jit_size = total_size;
     return 0;
 }
@@ -484,6 +505,7 @@ _PyJIT_Free(_PyExecutorObject *executor)
     size_t size = executor->jit_size;
     if (memory) {
         executor->jit_code = NULL;
+        executor->jit_side_entry = NULL;
         executor->jit_size = 0;
         if (jit_free(memory, size)) {
             PyErr_WriteUnraisable(NULL);
diff --git a/Python/optimizer.c b/Python/optimizer.c
index 9ba8d84a47dcd9..2389338531b0f3 100644
--- a/Python/optimizer.c
+++ b/Python/optimizer.c
@@ -1188,6 +1188,7 @@ make_executor_from_uops(_PyUOpInstruction *buffer, int length, const _PyBloomFil
 #endif
 #ifdef _Py_JIT
     executor->jit_code = NULL;
+    executor->jit_side_entry = NULL;
     executor->jit_size = 0;
     if (_PyJIT_Compile(executor, executor->trace, length)) {
         Py_DECREF(executor);
@@ -1219,6 +1220,7 @@ init_cold_exit_executor(_PyExecutorObject *executor, int oparg)
 #endif
 #ifdef _Py_JIT
     executor->jit_code = NULL;
+    executor->jit_side_entry = NULL;
     executor->jit_size = 0;
     if (_PyJIT_Compile(executor, executor->trace, 1)) {
         return -1;
diff --git a/Tools/jit/_targets.py b/Tools/jit/_targets.py
index 91734b36b4ab1b..274d17bcf38deb 100644
--- a/Tools/jit/_targets.py
+++ b/Tools/jit/_targets.py
@@ -38,6 +38,7 @@ class _Target(typing.Generic[_S, _R]):
     _: dataclasses.KW_ONLY
     alignment: int = 1
     args: typing.Sequence[str] = ()
+    ghccc: bool = False
     prefix: str = ""
     debug: bool = False
     force: bool = False
@@ -85,7 +86,11 @@ async def _parse(self, path: pathlib.Path) -> _stencils.StencilGroup:
         sections: list[dict[typing.Literal["Section"], _S]] = json.loads(output)
         for wrapped_section in sections:
             self._handle_section(wrapped_section["Section"], group)
-        assert group.symbols["_JIT_ENTRY"] == (_stencils.HoleValue.CODE, 0)
+        # The trampoline's entry point is just named "_ENTRY", since on some
+        # platforms we later assume that any function starting with "_JIT_" uses
+        # the GHC calling convention:
+        entry_symbol = "_JIT_ENTRY" if "_JIT_ENTRY" in group.symbols else "_ENTRY"
+        assert group.symbols[entry_symbol] == (_stencils.HoleValue.CODE, 0)
         if group.data.body:
             line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
             group.data.disassembly.append(line)
@@ -103,6 +108,9 @@ def _handle_relocation(
     async def _compile(
         self, opname: str, c: pathlib.Path, tempdir: pathlib.Path
     ) -> _stencils.StencilGroup:
+        # "Compile" the trampoline to an empty stencil group if it's not needed:
+        if opname == "trampoline" and not self.ghccc:
+            return _stencils.StencilGroup()
         o = tempdir / f"{opname}.o"
         args = [
             f"--target={self.triple}",
@@ -130,13 +138,38 @@ async def _compile(
             "-fno-plt",
             # Don't call stack-smashing canaries that we can't find or patch:
             "-fno-stack-protector",
-            "-o",
-            f"{o}",
             "-std=c11",
-            f"{c}",
             *self.args,
         ]
-        await _llvm.run("clang", args, echo=self.verbose)
+        if self.ghccc:
+            # This is a bit of an ugly workaround, but it makes the code much
+            # smaller and faster, so it's worth it. We want to use the GHC
+            # calling convention, but Clang doesn't support it. So, we *first*
+            # compile the code to LLVM IR, perform some text replacements on the
+            # IR to change the calling convention(!), and then compile *that*.
+            # Once we have access to Clang 19, we can get rid of this and use
+            # __attribute__((preserve_none)) directly in the C code instead:
+            ll = tempdir / f"{opname}.ll"
+            args_ll = args + [
+                # -fomit-frame-pointer is necessary because the GHC calling
+                # convention uses RBP to pass arguments:
+                "-S", "-emit-llvm", "-fomit-frame-pointer", "-o", f"{ll}", f"{c}"
+            ]
+            await _llvm.run("clang", args_ll, echo=self.verbose)
+            ir = ll.read_text()
+            # This handles declarations, definitions, and calls to named symbols
+            # starting with "_JIT_":
+            ir = re.sub(r"(((noalias|nonnull|noundef) )*ptr @_JIT_\w+\()", r"ghccc \1", ir)
+            # This handles calls to anonymous callees, since anything with
+            # "musttail" needs to use the same calling convention:
+            ir = ir.replace("musttail call", "musttail call ghccc")
+            # Sometimes *both* replacements happen at the same site, so fix it:
+            ir = ir.replace("ghccc ghccc", "ghccc")
+            ll.write_text(ir)
+            args_o = args + ["-Wno-unused-command-line-argument", "-o", f"{o}", f"{ll}"]
+        else:
+            args_o = args + ["-o", f"{o}", f"{c}"]
+        await _llvm.run("clang", args_o, echo=self.verbose)
         return await self._parse(o)
 
     async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]:
@@ -146,6 +179,8 @@ async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]:
         with tempfile.TemporaryDirectory() as tempdir:
             work = pathlib.Path(tempdir).resolve()
             async with asyncio.TaskGroup() as group:
+                coro = self._compile("trampoline", TOOLS_JIT / "trampoline.c", work)
+                tasks.append(group.create_task(coro, name="trampoline"))
                 for opname in opnames:
                     coro = self._compile(opname, TOOLS_JIT_TEMPLATE_C, work)
                     tasks.append(group.create_task(coro, name=opname))
@@ -445,6 +480,7 @@ def _handle_relocation(
 
 def get_target(host: str) -> _COFF | _ELF | _MachO:
     """Build a _Target for the given host "triple" and options."""
+    # ghccc currently crashes Clang when combined with musttail on aarch64. :(
     if re.fullmatch(r"aarch64-apple-darwin.*", host):
         return _MachO(host, alignment=8, prefix="_")
     if re.fullmatch(r"aarch64-pc-windows-msvc", host):
@@ -455,13 +491,13 @@ def get_target(host: str) -> _COFF | _ELF | _MachO:
         return _ELF(host, alignment=8, args=args)
     if re.fullmatch(r"i686-pc-windows-msvc", host):
         args = ["-DPy_NO_ENABLE_SHARED"]
-        return _COFF(host, args=args, prefix="_")
+        return _COFF(host, args=args, ghccc=True, prefix="_")
     if re.fullmatch(r"x86_64-apple-darwin.*", host):
-        return _MachO(host, prefix="_")
+        return _MachO(host, ghccc=True, prefix="_")
     if re.fullmatch(r"x86_64-pc-windows-msvc", host):
         args = ["-fms-runtime-lib=dll"]
-        return _COFF(host, args=args)
+        return _COFF(host, args=args, ghccc=True)
     if re.fullmatch(r"x86_64-.*-linux-gnu", host):
         args = ["-fpic"]
-        return _ELF(host, args=args)
+        return _ELF(host, args=args, ghccc=True)
     raise ValueError(host)
diff --git a/Tools/jit/_writer.py b/Tools/jit/_writer.py
index cbc1ed2fa6543a..6b36d8a9c66a3f 100644
--- a/Tools/jit/_writer.py
+++ b/Tools/jit/_writer.py
@@ -53,9 +53,13 @@ def _dump_footer(opnames: typing.Iterable[str]) -> typing.Iterator[str]:
     yield ""
     yield "static const StencilGroup stencil_groups[512] = {"
     for opname in opnames:
+        if opname == "trampoline":
+            continue
         yield f"    [{opname}] = INIT_STENCIL_GROUP({opname}),"
     yield "};"
     yield ""
+    yield "static const StencilGroup trampoline = INIT_STENCIL_GROUP(trampoline);"
+    yield ""
     yield "#define GET_PATCHES() { \\"
     for value in _stencils.HoleValue:
         yield f"    [HoleValue_{value.name}] = (uintptr_t)0xBADBADBADBADBADB, \\"
diff --git a/Tools/jit/template.c b/Tools/jit/template.c
index 3e81fd15bb8093..0dd0744f7aec9c 100644
--- a/Tools/jit/template.c
+++ b/Tools/jit/template.c
@@ -48,7 +48,7 @@
 do {  \
     OPT_STAT_INC(traces_executed);                \
     __attribute__((musttail))                     \
-    return ((jit_func)((EXECUTOR)->jit_code))(frame, stack_pointer, tstate); \
+    return ((jit_func)((EXECUTOR)->jit_side_entry))(frame, stack_pointer, tstate); \
 } while (0)
 
 #undef GOTO_TIER_ONE
@@ -65,7 +65,7 @@ do {  \
 
 #define PATCH_VALUE(TYPE, NAME, ALIAS)  \
     PyAPI_DATA(void) ALIAS;             \
-    TYPE NAME = (TYPE)(uint64_t)&ALIAS;
+    TYPE NAME = (TYPE)(uintptr_t)&ALIAS;
 
 #define PATCH_JUMP(ALIAS)                                    \
 do {                                                         \
diff --git a/Tools/jit/trampoline.c b/Tools/jit/trampoline.c
new file mode 100644
index 00000000000000..01b3d63a6790ba
--- /dev/null
+++ b/Tools/jit/trampoline.c
@@ -0,0 +1,25 @@
+#include "Python.h"
+
+#include "pycore_ceval.h"
+#include "pycore_frame.h"
+#include "pycore_jit.h"
+
+// This is where the calling convention changes, on platforms that require it.
+// The actual change is patched in while the JIT compiler is being built, in
+// Tools/jit/_targets.py. On other platforms, this function compiles to nothing.
+_Py_CODEUNIT *
+_ENTRY(_PyInterpreterFrame *frame, PyObject **stack_pointer, PyThreadState *tstate)
+{
+    // This is subtle. The actual trace will return to us once it exits, so we
+    // need to make sure that we stay alive until then. If our trace side-exits
+    // into another trace, and this trace is then invalidated, the code for
+    // *this function* will be freed and we'll crash upon return:
+    PyAPI_DATA(void) _JIT_EXECUTOR;
+    PyObject *executor = (PyObject *)(uintptr_t)&_JIT_EXECUTOR;
+    Py_INCREF(executor);
+    // Note that this is *not* a tail call:
+    PyAPI_DATA(void) _JIT_CONTINUE;
+    _Py_CODEUNIT *target = ((jit_func)&_JIT_CONTINUE)(frame, stack_pointer, tstate);
+    Py_SETREF(tstate->previous_executor, executor);
+    return target;
+}

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to