================
@@ -219,200 +215,284 @@ class BoxedProcedurePass
inline mlir::ModuleOp getModule() { return getOperation(); }
void runOnOperation() override final {
- if (options.useThunks) {
+ if (useThunks) {
auto *context = &getContext();
mlir::IRRewriter rewriter(context);
BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
- getModule().walk([&](mlir::Operation *op) {
- bool opIsValid = true;
- typeConverter.setLocation(op->getLoc());
- if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
- mlir::Type ty = addr.getVal().getType();
- mlir::Type resTy = addr.getResult().getType();
- if (llvm::isa<mlir::FunctionType>(ty) ||
- llvm::isa<fir::BoxProcType>(ty)) {
- // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
- // or function type to be `fir.convert` ops.
- rewriter.setInsertionPoint(addr);
- rewriter.replaceOpWithNewOp<ConvertOp>(
- addr, typeConverter.convertType(addr.getType()), addr.getVal());
- opIsValid = false;
- } else if (typeConverter.needsConversion(resTy)) {
- rewriter.startOpModification(op);
- op->getResult(0).setType(typeConverter.convertType(resTy));
- rewriter.finalizeOpModification(op);
- }
- } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
- mlir::FunctionType ty = func.getFunctionType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.startOpModification(func);
- auto toTy =
- mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
- if (!func.empty())
- for (auto e : llvm::enumerate(toTy.getInputs())) {
- unsigned i = e.index();
- auto &block = func.front();
- block.insertArgument(i, e.value(), func.getLoc());
- block.getArgument(i + 1).replaceAllUsesWith(
- block.getArgument(i));
- block.eraseArgument(i + 1);
- }
- func.setType(toTy);
- rewriter.finalizeOpModification(func);
- }
- } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
- // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
- // as required.
- mlir::Type toTy = typeConverter.convertType(
- mlir::cast<BoxProcType>(embox.getType()).getEleTy());
- rewriter.setInsertionPoint(embox);
- if (embox.getHost()) {
- // Create the thunk.
- auto module = embox->getParentOfType<mlir::ModuleOp>();
- FirOpBuilder builder(rewriter, module);
- const auto triple{fir::getTargetTriple(module)};
- auto loc = embox.getLoc();
- mlir::Type i8Ty = builder.getI8Type();
- mlir::Type i8Ptr = builder.getRefType(i8Ty);
- // For PPC32 and PPC64, the thunk is populated by a call to
- // __trampoline_setup, which is defined in
- // compiler-rt/lib/builtins/trampoline_setup.c and requires the
- // thunk size greater than 32 bytes. For AArch64, RISCV and x86_64,
- // the thunk setup doesn't go through __trampoline_setup and fits in
- // 32 bytes.
- fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
- mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
- auto buffer = AllocaOp::create(builder, loc, buffTy);
- mlir::Value closure =
- builder.createConvert(loc, i8Ptr, embox.getHost());
- mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
- mlir::Value func =
- builder.createConvert(loc, i8Ptr, embox.getFunc());
- fir::CallOp::create(
- builder, loc, factory::getLlvmInitTrampoline(builder),
- llvm::ArrayRef<mlir::Value>{tramp, func, closure});
- auto adjustCall = fir::CallOp::create(
- builder, loc, factory::getLlvmAdjustTrampoline(builder),
- llvm::ArrayRef<mlir::Value>{tramp});
- rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
- adjustCall.getResult(0));
- opIsValid = false;
- } else {
- // Just forward the function as a pointer.
- rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
- embox.getFunc());
- opIsValid = false;
- }
- } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
- auto ty = global.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.startOpModification(global);
- auto toTy = typeConverter.convertType(ty);
- global.setType(toTy);
- rewriter.finalizeOpModification(global);
- }
- } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
- auto ty = mem.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.setInsertionPoint(mem);
- auto toTy = typeConverter.convertType(unwrapRefType(ty));
- bool isPinned = mem.getPinned();
- llvm::StringRef uniqName =
- mem.getUniqName().value_or(llvm::StringRef());
- llvm::StringRef bindcName =
- mem.getBindcName().value_or(llvm::StringRef());
- rewriter.replaceOpWithNewOp<AllocaOp>(
- mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
- mem.getShape());
- opIsValid = false;
- }
- } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
- auto ty = mem.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.setInsertionPoint(mem);
- auto toTy = typeConverter.convertType(unwrapRefType(ty));
- llvm::StringRef uniqName =
- mem.getUniqName().value_or(llvm::StringRef());
- llvm::StringRef bindcName =
- mem.getBindcName().value_or(llvm::StringRef());
- rewriter.replaceOpWithNewOp<AllocMemOp>(
- mem, toTy, uniqName, bindcName, mem.getTypeparams(),
- mem.getShape());
- opIsValid = false;
- }
- } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
- auto ty = coor.getType();
- mlir::Type baseTy = coor.getBaseType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(baseTy)) {
- rewriter.setInsertionPoint(coor);
- auto toTy = typeConverter.convertType(ty);
- auto toBaseTy = typeConverter.convertType(baseTy);
- rewriter.replaceOpWithNewOp<CoordinateOp>(
- coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy,
- coor.getFieldIndicesAttr());
- opIsValid = false;
- }
- } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
- auto ty = index.getType();
- mlir::Type onTy = index.getOnType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(onTy)) {
- rewriter.setInsertionPoint(index);
- auto toTy = typeConverter.convertType(ty);
- auto toOnTy = typeConverter.convertType(onTy);
- rewriter.replaceOpWithNewOp<FieldIndexOp>(
- index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
- opIsValid = false;
- }
- } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
- auto ty = index.getType();
- mlir::Type onTy = index.getOnType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(onTy)) {
- rewriter.setInsertionPoint(index);
- auto toTy = typeConverter.convertType(ty);
- auto toOnTy = typeConverter.convertType(onTy);
- rewriter.replaceOpWithNewOp<LenParamIndexOp>(
- index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
- opIsValid = false;
+
+ // When using runtime trampolines, we need to track handles per
+ // function so we can insert FreeTrampoline calls at each return.
+ // Process functions individually to manage this state.
+ if (useRuntimeTrampoline) {
+ getModule().walk([&](mlir::func::FuncOp funcOp) {
+ trampolineHandles.clear();
+ processFunction(funcOp, rewriter, typeConverter);
+ insertTrampolineFrees(funcOp, rewriter);
+ });
+ // Also process non-function ops at module level (globals, etc.)
+ processModuleLevelOps(rewriter, typeConverter);
+ } else {
+ getModule().walk([&](mlir::Operation *op) {
+ processOp(op, rewriter, typeConverter);
+ });
+ }
+ }
+ }
+
+private:
+ /// Trampoline handles collected while processing a function.
+ /// Each entry is a Value representing the opaque handle returned
+ /// by _FortranATrampolineInit, which must be freed before the
+ /// function returns.
+ llvm::SmallVector<mlir::Value> trampolineHandles;
+
+ /// Process all ops within a function.
+ void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ funcOp.walk(
+ [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
+ }
+
+ /// Process non-function ops at module level (globals, etc.)
+ void processModuleLevelOps(mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ for (auto &op : getModule().getBody()->getOperations()) {
+ if (!mlir::isa<mlir::func::FuncOp>(op))
+ processOp(&op, rewriter, typeConverter);
+ }
----------------
vzakhari wrote:
```suggestion
for (auto &op : getModule().getBody()->getOperations())
if (!mlir::isa<mlir::func::FuncOp>(op))
processOp(&op, rewriter, typeConverter);
```
https://github.com/llvm/llvm-project/pull/183108
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits