================
@@ -219,200 +215,280 @@ class BoxedProcedurePass
   inline mlir::ModuleOp getModule() { return getOperation(); }
 
   void runOnOperation() override final {
-    if (options.useThunks) {
+    if (useThunks) {
       auto *context = &getContext();
       mlir::IRRewriter rewriter(context);
       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
-      getModule().walk([&](mlir::Operation *op) {
-        bool opIsValid = true;
-        typeConverter.setLocation(op->getLoc());
-        if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
-          mlir::Type ty = addr.getVal().getType();
-          mlir::Type resTy = addr.getResult().getType();
-          if (llvm::isa<mlir::FunctionType>(ty) ||
-              llvm::isa<fir::BoxProcType>(ty)) {
-            // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
-            // or function type to be `fir.convert` ops.
-            rewriter.setInsertionPoint(addr);
-            rewriter.replaceOpWithNewOp<ConvertOp>(
-                addr, typeConverter.convertType(addr.getType()), addr.getVal());
-            opIsValid = false;
-          } else if (typeConverter.needsConversion(resTy)) {
-            rewriter.startOpModification(op);
-            op->getResult(0).setType(typeConverter.convertType(resTy));
-            rewriter.finalizeOpModification(op);
-          }
-        } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
-          mlir::FunctionType ty = func.getFunctionType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(func);
-            auto toTy =
-                mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
-            if (!func.empty())
-              for (auto e : llvm::enumerate(toTy.getInputs())) {
-                unsigned i = e.index();
-                auto &block = func.front();
-                block.insertArgument(i, e.value(), func.getLoc());
-                block.getArgument(i + 1).replaceAllUsesWith(
-                    block.getArgument(i));
-                block.eraseArgument(i + 1);
-              }
-            func.setType(toTy);
-            rewriter.finalizeOpModification(func);
-          }
-        } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
-          // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
-          // as required.
-          mlir::Type toTy = typeConverter.convertType(
-              mlir::cast<BoxProcType>(embox.getType()).getEleTy());
-          rewriter.setInsertionPoint(embox);
-          if (embox.getHost()) {
-            // Create the thunk.
-            auto module = embox->getParentOfType<mlir::ModuleOp>();
-            FirOpBuilder builder(rewriter, module);
-            const auto triple{fir::getTargetTriple(module)};
-            auto loc = embox.getLoc();
-            mlir::Type i8Ty = builder.getI8Type();
-            mlir::Type i8Ptr = builder.getRefType(i8Ty);
-            // For PPC32 and PPC64, the thunk is populated by a call to
-            // __trampoline_setup, which is defined in
-            // compiler-rt/lib/builtins/trampoline_setup.c and requires the
-            // thunk size greater than 32 bytes.  For AArch64, RISCV and x86_64,
-            // the thunk setup doesn't go through __trampoline_setup and fits in
-            // 32 bytes.
-            fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
-            mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
-            auto buffer = AllocaOp::create(builder, loc, buffTy);
-            mlir::Value closure =
-                builder.createConvert(loc, i8Ptr, embox.getHost());
-            mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
-            mlir::Value func =
-                builder.createConvert(loc, i8Ptr, embox.getFunc());
-            fir::CallOp::create(
-                builder, loc, factory::getLlvmInitTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp, func, closure});
-            auto adjustCall = fir::CallOp::create(
-                builder, loc, factory::getLlvmAdjustTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp});
-            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   adjustCall.getResult(0));
-            opIsValid = false;
-          } else {
-            // Just forward the function as a pointer.
-            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   embox.getFunc());
-            opIsValid = false;
-          }
-        } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
-          auto ty = global.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(global);
-            auto toTy = typeConverter.convertType(ty);
-            global.setType(toTy);
-            rewriter.finalizeOpModification(global);
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            bool isPinned = mem.getPinned();
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocaOp>(
-                mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocMemOp>(
-                mem, toTy, uniqName, bindcName, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
-          auto ty = coor.getType();
-          mlir::Type baseTy = coor.getBaseType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(baseTy)) {
-            rewriter.setInsertionPoint(coor);
-            auto toTy = typeConverter.convertType(ty);
-            auto toBaseTy = typeConverter.convertType(baseTy);
-            rewriter.replaceOpWithNewOp<CoordinateOp>(
-                coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy,
-                coor.getFieldIndicesAttr());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<FieldIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<LenParamIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
-            opIsValid = false;
+
+      // When using safe trampolines, we need to track handles per
+      // function so we can insert FreeTrampoline calls at each return.
+      // Process functions individually to manage this state.
+      if (useSafeTrampoline) {
+        getModule().walk([&](mlir::func::FuncOp funcOp) {
+          trampolineHandles.clear();
+          processFunction(funcOp, rewriter, typeConverter);
+          insertTrampolineFrees(funcOp, rewriter);
+        });
+        // Also process non-function ops at module level (globals, etc.)
+        processModuleLevelOps(rewriter, typeConverter);
+      } else {
+        getModule().walk([&](mlir::Operation *op) {
+          processOp(op, rewriter, typeConverter);
+        });
+      }
+    }
+  }
+
+private:
+  /// Trampoline handles collected while processing a function.
+  /// Each entry is a Value representing the opaque handle returned
+  /// by _FortranATrampolineInit, which must be freed before the
+  /// function returns.
+  llvm::SmallVector<mlir::Value> trampolineHandles;
+
+  /// Process all ops within a function.
+  void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
+                       BoxprocTypeRewriter &typeConverter) {
+    funcOp.walk(
+        [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
+  }
+
+  /// Process non-function ops at module level (globals, etc.)
+  void processModuleLevelOps(mlir::IRRewriter &rewriter,
+                             BoxprocTypeRewriter &typeConverter) {
+    for (auto &op : getModule().getBody()->getOperations())
+      if (!mlir::isa<mlir::func::FuncOp>(op))
+        processOp(&op, rewriter, typeConverter);
+  }
+
+  /// Insert _FortranATrampolineFree calls before every return in the function.
+  void insertTrampolineFrees(mlir::func::FuncOp funcOp,
+                             mlir::IRRewriter &rewriter) {
+    if (trampolineHandles.empty())
+      return;
+
+    auto module{funcOp->getParentOfType<mlir::ModuleOp>()};
+    // Insert TrampolineFree calls before every func.return in this function.
+    // At this pass stage (after CFGConversion), func.return is the only
+    // terminator that exits the function. Other terminators are either
+    // intra-function branches (cf.br, cf.cond_br, fir.select*) or
+    // fir.unreachable (after STOP/ERROR STOP), which don't need cleanup
+    // since the process is terminating.
+    funcOp.walk([&](mlir::func::ReturnOp retOp) {
+      rewriter.setInsertionPoint(retOp);
+      FirOpBuilder builder(rewriter, module);
+      auto loc{retOp.getLoc()};
+      for (mlir::Value handle : trampolineHandles)
+        fir::runtime::genTrampolineFree(builder, loc, handle);
+    });
+  }
+
+  /// Process a single operation for boxproc type rewriting.
+  void processOp(mlir::Operation *op, mlir::IRRewriter &rewriter,
+                 BoxprocTypeRewriter &typeConverter) {
+    bool opIsValid{true};
+    typeConverter.setLocation(op->getLoc());
+    if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
+      mlir::Type ty{addr.getVal().getType()};
+      mlir::Type resTy{addr.getResult().getType()};
+      if (llvm::isa<mlir::FunctionType>(ty) ||
+          llvm::isa<fir::BoxProcType>(ty)) {
+        // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
+        // or function type to be `fir.convert` ops.
+        rewriter.setInsertionPoint(addr);
+        rewriter.replaceOpWithNewOp<ConvertOp>(
+            addr, typeConverter.convertType(addr.getType()), addr.getVal());
+        opIsValid = false;
+      } else if (typeConverter.needsConversion(resTy)) {
+        rewriter.startOpModification(op);
+        op->getResult(0).setType(typeConverter.convertType(resTy));
+        rewriter.finalizeOpModification(op);
+      }
+    } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
+      mlir::FunctionType ty{func.getFunctionType()};
+      if (typeConverter.needsConversion(ty)) {
+        rewriter.startOpModification(func);
+        auto toTy{
+            mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty))};
+        if (!func.empty())
+          for (auto e : llvm::enumerate(toTy.getInputs())) {
+            auto i{static_cast<unsigned>(e.index())};
+            auto &block{func.front()};
+            block.insertArgument(i, e.value(), func.getLoc());
+            block.getArgument(i + 1).replaceAllUsesWith(block.getArgument(i));
+            block.eraseArgument(i + 1);
           }
-        } else {
-          rewriter.startOpModification(op);
-          // Convert the operands if needed
-          for (auto i : llvm::enumerate(op->getResultTypes()))
-            if (typeConverter.needsConversion(i.value())) {
-              auto toTy = typeConverter.convertType(i.value());
-              op->getResult(i.index()).setType(toTy);
-            }
+        func.setType(toTy);
+        rewriter.finalizeOpModification(func);
+      }
+    } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
+      // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
+      // as required.
+      mlir::Type toTy{typeConverter.convertType(
+          mlir::cast<BoxProcType>(embox.getType()).getEleTy())};
+      rewriter.setInsertionPoint(embox);
+      if (embox.getHost()) {
+        auto module{embox->getParentOfType<mlir::ModuleOp>()};
+        FirOpBuilder builder(rewriter, module);
+        auto loc{embox.getLoc()};
+        mlir::Type i8Ty{builder.getI8Type()};
+        mlir::Type i8Ptr{builder.getRefType(i8Ty)};
+
+        if (useSafeTrampoline) {
+          // Runtime trampoline pool path (W^X compliant).
+          // Instead of allocating a writable+executable buffer on the
+          // stack, call the runtime to allocate from a pre-initialized
+          // pool with separate RX (code) and RW (data) regions.
+          mlir::Value nullPtr{builder.createNullConstant(loc, i8Ptr)};
+          mlir::Value closure{
+              builder.createConvert(loc, i8Ptr, embox.getHost())};
+          mlir::Value func{builder.createConvert(loc, i8Ptr, embox.getFunc())};
 
-          // Convert the type attributes if needed
-          for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
-            if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
-              if (typeConverter.needsConversion(tyAttr.getValue())) {
-                auto toTy = typeConverter.convertType(tyAttr.getValue());
-                op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
-              }
-          rewriter.finalizeOpModification(op);
+          // _FortranATrampolineInit(nullptr, func, closure) -> handle
+          mlir::Value handle{fir::runtime::genTrampolineInit(
+              builder, loc, nullPtr, func, closure)};
+
+          // _FortranATrampolineAdjust(handle) -> callable address
+          mlir::Value callableAddr{
+              fir::runtime::genTrampolineAdjust(builder, loc, handle)};
+
+          // Track the handle so we can free it at function exits.
+          trampolineHandles.push_back(handle);
----------------
Saieiei wrote:

Done.
The Init/Adjust calls are now hoisted to the function's entry block so that the
handle dominates all `func.return` sites where `TrampolineFree` is emitted. For
emboxproc ops directly in the entry block, the calls are inserted right before
the emboxproc. For ops inside branches (e.g. `cf.cond_br` targets), the walk-up
finds the top-level ancestor in the entry block (or falls back to the
entry-block terminator) and inserts the calls before it. I also added a
`trampolineCallableMap`, keyed by the func SSA value, to deduplicate: if the
same internal procedure is emboxed multiple times in one host function, only
one trampoline is allocated. Added FIR and Fortran test cases with an emboxproc
inside a branch (`_QPtest_branch` in boxproc-safe-trampoline.fir, `host_branch`
in safe-trampoline.f90).
https://github.com/llvm/llvm-project/pull/183108/changes/448886e8394650220d7e1dcf6fb984205d043fcc

https://github.com/llvm/llvm-project/pull/183108
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to