================
@@ -219,200 +215,359 @@ class BoxedProcedurePass
   inline mlir::ModuleOp getModule() { return getOperation(); }
 
   void runOnOperation() override final {
-    if (options.useThunks) {
+    if (useThunks) {
       auto *context = &getContext();
       mlir::IRRewriter rewriter(context);
       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
-      getModule().walk([&](mlir::Operation *op) {
-        bool opIsValid = true;
-        typeConverter.setLocation(op->getLoc());
-        if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
-          mlir::Type ty = addr.getVal().getType();
-          mlir::Type resTy = addr.getResult().getType();
-          if (llvm::isa<mlir::FunctionType>(ty) ||
-              llvm::isa<fir::BoxProcType>(ty)) {
-            // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
-            // or function type to be `fir.convert` ops.
-            rewriter.setInsertionPoint(addr);
-            rewriter.replaceOpWithNewOp<ConvertOp>(
-                addr, typeConverter.convertType(addr.getType()), addr.getVal());
-            opIsValid = false;
-          } else if (typeConverter.needsConversion(resTy)) {
-            rewriter.startOpModification(op);
-            op->getResult(0).setType(typeConverter.convertType(resTy));
-            rewriter.finalizeOpModification(op);
-          }
-        } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
-          mlir::FunctionType ty = func.getFunctionType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(func);
-            auto toTy =
-                mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
-            if (!func.empty())
-              for (auto e : llvm::enumerate(toTy.getInputs())) {
-                unsigned i = e.index();
-                auto &block = func.front();
-                block.insertArgument(i, e.value(), func.getLoc());
-                block.getArgument(i + 1).replaceAllUsesWith(
-                    block.getArgument(i));
-                block.eraseArgument(i + 1);
-              }
-            func.setType(toTy);
-            rewriter.finalizeOpModification(func);
+
+      // When using safe trampolines, we need to track handles per
+      // function so we can insert FreeTrampoline calls at each return.
+      // Process functions individually to manage this state.
+      if (useSafeTrampoline) {
+        getModule().walk([&](mlir::func::FuncOp funcOp) {
+          trampolineHandles.clear();
+          trampolineCallableMap.clear();
+          processFunction(funcOp, rewriter, typeConverter);
+          insertTrampolineFrees(funcOp, rewriter);
+        });
+        // Also process non-function ops at module level (globals, etc.)
+        processModuleLevelOps(rewriter, typeConverter);
+      } else {
+        getModule().walk([&](mlir::Operation *op) {
+          processOp(op, rewriter, typeConverter);
+        });
+      }
+    }
+  }
+
+private:
+  /// Trampoline handles collected while processing a function.
+  /// Each entry is a Value representing the opaque handle returned
+  /// by _FortranATrampolineInit, which must be freed before the
+  /// function returns.
+  llvm::SmallVector<mlir::Value> trampolineHandles;
+
+  /// Cache of trampoline callable addresses keyed by the func SSA value
+  /// of the emboxproc. This deduplicates trampolines when the same
+  /// internal procedure is emboxed multiple times in one host function.
+  llvm::DenseMap<mlir::Value, mlir::Value> trampolineCallableMap;
+
+  /// Process all ops within a function.
+  void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
+                       BoxprocTypeRewriter &typeConverter) {
+    funcOp.walk(
+        [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
+  }
+
+  /// Process non-function ops at module level (globals, etc.)
+  void processModuleLevelOps(mlir::IRRewriter &rewriter,
+                             BoxprocTypeRewriter &typeConverter) {
+    for (auto &op : getModule().getBody()->getOperations())
+      if (!mlir::isa<mlir::func::FuncOp>(op))
+        processOp(&op, rewriter, typeConverter);
+  }
+
+  /// Insert _FortranATrampolineFree calls before every return in the function.
+  void insertTrampolineFrees(mlir::func::FuncOp funcOp,
+                             mlir::IRRewriter &rewriter) {
+    if (trampolineHandles.empty())
+      return;
+
+    auto module{funcOp->getParentOfType<mlir::ModuleOp>()};
+    // Insert TrampolineFree calls before every func.return in this function.
+    // At this pass stage (after CFGConversion), func.return is the only
+    // terminator that exits the function. Other terminators are either
+    // intra-function branches (cf.br, cf.cond_br, fir.select*) or
+    // fir.unreachable (after STOP/ERROR STOP), which don't need cleanup
+    // since the process is terminating.
+    funcOp.walk([&](mlir::func::ReturnOp retOp) {
+      rewriter.setInsertionPoint(retOp);
+      FirOpBuilder builder(rewriter, module);
+      auto loc{retOp.getLoc()};
+      for (mlir::Value handle : trampolineHandles)
+        fir::runtime::genTrampolineFree(builder, loc, handle);
+    });
+  }
+
+  /// Process a single operation for boxproc type rewriting.
+  void processOp(mlir::Operation *op, mlir::IRRewriter &rewriter,
+                 BoxprocTypeRewriter &typeConverter) {
+    bool opIsValid{true};
+    typeConverter.setLocation(op->getLoc());
+    if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
+      mlir::Type ty{addr.getVal().getType()};
+      mlir::Type resTy{addr.getResult().getType()};
+      if (llvm::isa<mlir::FunctionType>(ty) ||
+          llvm::isa<fir::BoxProcType>(ty)) {
+        // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
+        // or function type to be `fir.convert` ops.
+        rewriter.setInsertionPoint(addr);
+        rewriter.replaceOpWithNewOp<ConvertOp>(
+            addr, typeConverter.convertType(addr.getType()), addr.getVal());
+        opIsValid = false;
+      } else if (typeConverter.needsConversion(resTy)) {
+        rewriter.startOpModification(op);
+        op->getResult(0).setType(typeConverter.convertType(resTy));
+        rewriter.finalizeOpModification(op);
+      }
+    } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
+      mlir::FunctionType ty{func.getFunctionType()};
+      if (typeConverter.needsConversion(ty)) {
+        rewriter.startOpModification(func);
+        auto toTy{
+            mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty))};
+        if (!func.empty())
+          for (auto e : llvm::enumerate(toTy.getInputs())) {
+            auto i{static_cast<unsigned>(e.index())};
+            auto &block{func.front()};
+            block.insertArgument(i, e.value(), func.getLoc());
+            block.getArgument(i + 1).replaceAllUsesWith(block.getArgument(i));
+            block.eraseArgument(i + 1);
           }
-        } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
-          // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
-          // as required.
-          mlir::Type toTy = typeConverter.convertType(
-              mlir::cast<BoxProcType>(embox.getType()).getEleTy());
-          rewriter.setInsertionPoint(embox);
-          if (embox.getHost()) {
-            // Create the thunk.
-            auto module = embox->getParentOfType<mlir::ModuleOp>();
-            FirOpBuilder builder(rewriter, module);
-            const auto triple{fir::getTargetTriple(module)};
-            auto loc = embox.getLoc();
-            mlir::Type i8Ty = builder.getI8Type();
-            mlir::Type i8Ptr = builder.getRefType(i8Ty);
-            // For PPC32 and PPC64, the thunk is populated by a call to
-            // __trampoline_setup, which is defined in
-            // compiler-rt/lib/builtins/trampoline_setup.c and requires the
-            // thunk size greater than 32 bytes.  For AArch64, RISCV and x86_64,
-            // the thunk setup doesn't go through __trampoline_setup and fits in
-            // 32 bytes.
-            fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
-            mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
-            auto buffer = AllocaOp::create(builder, loc, buffTy);
-            mlir::Value closure =
-                builder.createConvert(loc, i8Ptr, embox.getHost());
-            mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
-            mlir::Value func =
-                builder.createConvert(loc, i8Ptr, embox.getFunc());
-            fir::CallOp::create(
-                builder, loc, factory::getLlvmInitTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp, func, closure});
-            auto adjustCall = fir::CallOp::create(
-                builder, loc, factory::getLlvmAdjustTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp});
+        func.setType(toTy);
+        rewriter.finalizeOpModification(func);
+      }
+    } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
+      // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
+      // as required.
+      mlir::Type toTy{typeConverter.convertType(
+          mlir::cast<BoxProcType>(embox.getType()).getEleTy())};
+      rewriter.setInsertionPoint(embox);
+      if (embox.getHost()) {
+        auto module{embox->getParentOfType<mlir::ModuleOp>()};
+        auto loc{embox.getLoc()};
+
+        if (useSafeTrampoline) {
+          // Runtime trampoline pool path (W^X compliant).
+          // Insert Init/Adjust in the function's entry block so the
+          // handle dominates all func.return ops where TrampolineFree
+          // is emitted. This is necessary because fir.emboxproc may
+          // appear inside control flow branches. A cache avoids
+          // creating duplicate trampolines for the same internal
+          // procedure within a single host function.
+          mlir::Value funcVal{embox.getFunc()};
+          auto cacheIt{trampolineCallableMap.find(funcVal)};
+          if (cacheIt != trampolineCallableMap.end()) {
             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   adjustCall.getResult(0));
-            opIsValid = false;
+                                                   cacheIt->second);
           } else {
-            // Just forward the function as a pointer.
-            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   embox.getFunc());
-            opIsValid = false;
-          }
-        } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
-          auto ty = global.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(global);
-            auto toTy = typeConverter.convertType(ty);
-            global.setType(toTy);
-            rewriter.finalizeOpModification(global);
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            bool isPinned = mem.getPinned();
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocaOp>(
-                mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocMemOp>(
-                mem, toTy, uniqName, bindcName, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
-          auto ty = coor.getType();
-          mlir::Type baseTy = coor.getBaseType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(baseTy)) {
-            rewriter.setInsertionPoint(coor);
-            auto toTy = typeConverter.convertType(ty);
-            auto toBaseTy = typeConverter.convertType(baseTy);
-            rewriter.replaceOpWithNewOp<CoordinateOp>(
-                coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy,
-                coor.getFieldIndicesAttr());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<FieldIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<LenParamIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
-            opIsValid = false;
-          }
-        } else {
-          rewriter.startOpModification(op);
-          // Convert the operands if needed
-          for (auto i : llvm::enumerate(op->getResultTypes()))
-            if (typeConverter.needsConversion(i.value())) {
-              auto toTy = typeConverter.convertType(i.value());
-              op->getResult(i.index()).setType(toTy);
+            auto parentFunc{embox->getParentOfType<mlir::func::FuncOp>()};
+            auto &entryBlock{parentFunc.front()};
+
+            auto savedIP{rewriter.saveInsertionPoint()};
+
+            // Find the right insertion point in the entry block.
+            // Walk up from the emboxproc to find its top-level
+            // ancestor in the entry block. For an emboxproc directly
+            // in the entry block, this is the emboxproc itself.
+            // For one inside a structured op (fir.if, fir.do_loop),
+            // this is that structured op. For one inside an explicit
+            // branch target (cf.cond_br → ^bb1), we fall back to the
+            // entry block terminator.
+            mlir::Operation *entryAncestor{embox.getOperation()};
+            while (entryAncestor->getBlock() != &entryBlock) {
+              entryAncestor = entryAncestor->getParentOp();
+              if (!entryAncestor ||
+                  mlir::isa<mlir::func::FuncOp>(entryAncestor))
+                break;
             }
+            bool ancestorInEntry{
+                entryAncestor &&
+                !mlir::isa<mlir::func::FuncOp>(entryAncestor) &&
+                entryAncestor->getBlock() == &entryBlock};
 
-          // Convert the type attributes if needed
-          for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
-            if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
-              if (typeConverter.needsConversion(tyAttr.getValue())) {
-                auto toTy = typeConverter.convertType(tyAttr.getValue());
-                op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
+            // If the func value is not in the entry block (e.g.,
+            // address_of generated inside a structured fir.if),
+            // clone it into the entry block.
+            mlir::Value funcValInEntry{funcVal};
+            if (auto *funcDef{funcVal.getDefiningOp()}) {
+              if (funcDef->getBlock() != &entryBlock) {
+                if (ancestorInEntry)
+                  rewriter.setInsertionPoint(entryAncestor);
+                else
+                  rewriter.setInsertionPoint(entryBlock.getTerminator());
+                auto *cloned{rewriter.clone(*funcDef)};
+                funcValInEntry = cloned->getResult(0);
               }
-          rewriter.finalizeOpModification(op);
+            }
+
+            // Similarly clone the host value if not in entry block.
+            mlir::Value hostValInEntry{embox.getHost()};
+            if (auto *hostDef{embox.getHost().getDefiningOp()}) {
+              if (hostDef->getBlock() != &entryBlock) {
+                if (ancestorInEntry)
+                  rewriter.setInsertionPoint(entryAncestor);
+                else
+                  rewriter.setInsertionPoint(entryBlock.getTerminator());
+                auto *cloned{rewriter.clone(*hostDef)};
+                hostValInEntry = cloned->getResult(0);
+              }
----------------
jeanPerier wrote:

I think the best you can do here is to raise an error if the hostlink was not 
created inside the entry block because cloning just the defining op is not 
enough (it would most likely just clone an alloca without all the stores to it).

In practice host links are always allocated in the entry block in the host, and 
are function block arguments inside internal procedures, so there is no need to 
do cloning for it. But if there were, it would be best to raise an error and 
abort because the simple cloning logic you have would not be enough.

The cloning of the `funcDef` makes sense and you should keep it since the 
address_of are generated on the fly.

https://github.com/llvm/llvm-project/pull/183108
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to