================
@@ -219,200 +215,284 @@ class BoxedProcedurePass
   inline mlir::ModuleOp getModule() { return getOperation(); }
 
   void runOnOperation() override final {
-    if (options.useThunks) {
+    if (useThunks) {
       auto *context = &getContext();
       mlir::IRRewriter rewriter(context);
       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
-      getModule().walk([&](mlir::Operation *op) {
-        bool opIsValid = true;
-        typeConverter.setLocation(op->getLoc());
-        if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
-          mlir::Type ty = addr.getVal().getType();
-          mlir::Type resTy = addr.getResult().getType();
-          if (llvm::isa<mlir::FunctionType>(ty) ||
-              llvm::isa<fir::BoxProcType>(ty)) {
-            // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
-            // or function type to be `fir.convert` ops.
-            rewriter.setInsertionPoint(addr);
-            rewriter.replaceOpWithNewOp<ConvertOp>(
-                addr, typeConverter.convertType(addr.getType()), addr.getVal());
-            opIsValid = false;
-          } else if (typeConverter.needsConversion(resTy)) {
-            rewriter.startOpModification(op);
-            op->getResult(0).setType(typeConverter.convertType(resTy));
-            rewriter.finalizeOpModification(op);
-          }
-        } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
-          mlir::FunctionType ty = func.getFunctionType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(func);
-            auto toTy =
-                mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
-            if (!func.empty())
-              for (auto e : llvm::enumerate(toTy.getInputs())) {
-                unsigned i = e.index();
-                auto &block = func.front();
-                block.insertArgument(i, e.value(), func.getLoc());
-                block.getArgument(i + 1).replaceAllUsesWith(
-                    block.getArgument(i));
-                block.eraseArgument(i + 1);
-              }
-            func.setType(toTy);
-            rewriter.finalizeOpModification(func);
-          }
-        } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
-          // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
-          // as required.
-          mlir::Type toTy = typeConverter.convertType(
-              mlir::cast<BoxProcType>(embox.getType()).getEleTy());
-          rewriter.setInsertionPoint(embox);
-          if (embox.getHost()) {
-            // Create the thunk.
-            auto module = embox->getParentOfType<mlir::ModuleOp>();
-            FirOpBuilder builder(rewriter, module);
-            const auto triple{fir::getTargetTriple(module)};
-            auto loc = embox.getLoc();
-            mlir::Type i8Ty = builder.getI8Type();
-            mlir::Type i8Ptr = builder.getRefType(i8Ty);
-            // For PPC32 and PPC64, the thunk is populated by a call to
-            // __trampoline_setup, which is defined in
-            // compiler-rt/lib/builtins/trampoline_setup.c and requires the
-            // thunk size greater than 32 bytes.  For AArch64, RISCV and x86_64,
-            // the thunk setup doesn't go through __trampoline_setup and fits in
-            // 32 bytes.
-            fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
-            mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
-            auto buffer = AllocaOp::create(builder, loc, buffTy);
-            mlir::Value closure =
-                builder.createConvert(loc, i8Ptr, embox.getHost());
-            mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
-            mlir::Value func =
-                builder.createConvert(loc, i8Ptr, embox.getFunc());
-            fir::CallOp::create(
-                builder, loc, factory::getLlvmInitTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp, func, closure});
-            auto adjustCall = fir::CallOp::create(
-                builder, loc, factory::getLlvmAdjustTrampoline(builder),
-                llvm::ArrayRef<mlir::Value>{tramp});
-            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   adjustCall.getResult(0));
-            opIsValid = false;
-          } else {
-            // Just forward the function as a pointer.
-            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
-                                                   embox.getFunc());
-            opIsValid = false;
-          }
-        } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
-          auto ty = global.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.startOpModification(global);
-            auto toTy = typeConverter.convertType(ty);
-            global.setType(toTy);
-            rewriter.finalizeOpModification(global);
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            bool isPinned = mem.getPinned();
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocaOp>(
-                mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
-          auto ty = mem.getType();
-          if (typeConverter.needsConversion(ty)) {
-            rewriter.setInsertionPoint(mem);
-            auto toTy = typeConverter.convertType(unwrapRefType(ty));
-            llvm::StringRef uniqName =
-                mem.getUniqName().value_or(llvm::StringRef());
-            llvm::StringRef bindcName =
-                mem.getBindcName().value_or(llvm::StringRef());
-            rewriter.replaceOpWithNewOp<AllocMemOp>(
-                mem, toTy, uniqName, bindcName, mem.getTypeparams(),
-                mem.getShape());
-            opIsValid = false;
-          }
-        } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
-          auto ty = coor.getType();
-          mlir::Type baseTy = coor.getBaseType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(baseTy)) {
-            rewriter.setInsertionPoint(coor);
-            auto toTy = typeConverter.convertType(ty);
-            auto toBaseTy = typeConverter.convertType(baseTy);
-            rewriter.replaceOpWithNewOp<CoordinateOp>(
-                coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy,
-                coor.getFieldIndicesAttr());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<FieldIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
-            opIsValid = false;
-          }
-        } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
-          auto ty = index.getType();
-          mlir::Type onTy = index.getOnType();
-          if (typeConverter.needsConversion(ty) ||
-              typeConverter.needsConversion(onTy)) {
-            rewriter.setInsertionPoint(index);
-            auto toTy = typeConverter.convertType(ty);
-            auto toOnTy = typeConverter.convertType(onTy);
-            rewriter.replaceOpWithNewOp<LenParamIndexOp>(
-                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
-            opIsValid = false;
+
+      // When using runtime trampolines, we need to track handles per
+      // function so we can insert FreeTrampoline calls at each return.
+      // Process functions individually to manage this state.
+      if (useRuntimeTrampoline) {
+        getModule().walk([&](mlir::func::FuncOp funcOp) {
+          trampolineHandles.clear();
+          processFunction(funcOp, rewriter, typeConverter);
+          insertTrampolineFrees(funcOp, rewriter);
+        });
+        // Also process non-function ops at module level (globals, etc.)
+        processModuleLevelOps(rewriter, typeConverter);
+      } else {
+        getModule().walk([&](mlir::Operation *op) {
+          processOp(op, rewriter, typeConverter);
+        });
+      }
+    }
+  }
+
+private:
+  /// Trampoline handles collected while processing a function.
+  /// Each entry is a Value representing the opaque handle returned
+  /// by _FortranATrampolineInit, which must be freed before the
+  /// function returns.
+  llvm::SmallVector<mlir::Value> trampolineHandles;
+
+  /// Process all ops within a function.
+  void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
+                       BoxprocTypeRewriter &typeConverter) {
+    funcOp.walk(
+        [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
+  }
+
+  /// Process non-function ops at module level (globals, etc.)
+  void processModuleLevelOps(mlir::IRRewriter &rewriter,
+                             BoxprocTypeRewriter &typeConverter) {
+    for (auto &op : getModule().getBody()->getOperations()) {
+      if (!mlir::isa<mlir::func::FuncOp>(op))
+        processOp(&op, rewriter, typeConverter);
+    }
----------------
vzakhari wrote:

```suggestion
    for (auto &op : getModule().getBody()->getOperations())
      if (!mlir::isa<mlir::func::FuncOp>(op))
        processOp(&op, rewriter, typeConverter);
```

https://github.com/llvm/llvm-project/pull/183108
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to