================
@@ -219,200 +215,359 @@ class BoxedProcedurePass
inline mlir::ModuleOp getModule() { return getOperation(); }
void runOnOperation() override final {
- if (options.useThunks) {
+ if (useThunks) {
auto *context = &getContext();
mlir::IRRewriter rewriter(context);
BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
- getModule().walk([&](mlir::Operation *op) {
- bool opIsValid = true;
- typeConverter.setLocation(op->getLoc());
- if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
- mlir::Type ty = addr.getVal().getType();
- mlir::Type resTy = addr.getResult().getType();
- if (llvm::isa<mlir::FunctionType>(ty) ||
- llvm::isa<fir::BoxProcType>(ty)) {
- // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
- // or function type to be `fir.convert` ops.
- rewriter.setInsertionPoint(addr);
- rewriter.replaceOpWithNewOp<ConvertOp>(
- addr, typeConverter.convertType(addr.getType()),
addr.getVal());
- opIsValid = false;
- } else if (typeConverter.needsConversion(resTy)) {
- rewriter.startOpModification(op);
- op->getResult(0).setType(typeConverter.convertType(resTy));
- rewriter.finalizeOpModification(op);
- }
- } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
- mlir::FunctionType ty = func.getFunctionType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.startOpModification(func);
- auto toTy =
- mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
- if (!func.empty())
- for (auto e : llvm::enumerate(toTy.getInputs())) {
- unsigned i = e.index();
- auto &block = func.front();
- block.insertArgument(i, e.value(), func.getLoc());
- block.getArgument(i + 1).replaceAllUsesWith(
- block.getArgument(i));
- block.eraseArgument(i + 1);
- }
- func.setType(toTy);
- rewriter.finalizeOpModification(func);
+
+ // When using safe trampolines, we need to track handles per
+ // function so we can insert FreeTrampoline calls at each return.
+ // Process functions individually to manage this state.
+ if (useSafeTrampoline) {
+ getModule().walk([&](mlir::func::FuncOp funcOp) {
+ trampolineHandles.clear();
+ trampolineCallableMap.clear();
+ processFunction(funcOp, rewriter, typeConverter);
+ insertTrampolineFrees(funcOp, rewriter);
+ });
+ // Also process non-function ops at module level (globals, etc.)
+ processModuleLevelOps(rewriter, typeConverter);
+ } else {
+ getModule().walk([&](mlir::Operation *op) {
+ processOp(op, rewriter, typeConverter);
+ });
+ }
+ }
+ }
+
+private:
+ /// Trampoline handles collected while processing a function.
+ /// Each entry is a Value representing the opaque handle returned
+ /// by _FortranATrampolineInit, which must be freed before the
+ /// function returns.
+ llvm::SmallVector<mlir::Value> trampolineHandles;
+
+ /// Cache of trampoline callable addresses keyed by the func SSA value
+ /// of the emboxproc. This deduplicates trampolines when the same
+ /// internal procedure is emboxed multiple times in one host function.
+ llvm::DenseMap<mlir::Value, mlir::Value> trampolineCallableMap;
+
+ /// Process all ops within a function.
+ void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ funcOp.walk(
+ [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
+ }
+
+ /// Process non-function ops at module level (globals, etc.)
+ void processModuleLevelOps(mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ for (auto &op : getModule().getBody()->getOperations())
+ if (!mlir::isa<mlir::func::FuncOp>(op))
+ processOp(&op, rewriter, typeConverter);
+ }
+
+ /// Insert _FortranATrampolineFree calls before every return in the function.
+ void insertTrampolineFrees(mlir::func::FuncOp funcOp,
+ mlir::IRRewriter &rewriter) {
+ if (trampolineHandles.empty())
+ return;
+
+ auto module{funcOp->getParentOfType<mlir::ModuleOp>()};
+ // Insert TrampolineFree calls before every func.return in this function.
+ // At this pass stage (after CFGConversion), func.return is the only
+ // terminator that exits the function. Other terminators are either
+ // intra-function branches (cf.br, cf.cond_br, fir.select*) or
+ // fir.unreachable (after STOP/ERROR STOP), which don't need cleanup
+ // since the process is terminating.
+ funcOp.walk([&](mlir::func::ReturnOp retOp) {
+ rewriter.setInsertionPoint(retOp);
+ FirOpBuilder builder(rewriter, module);
+ auto loc{retOp.getLoc()};
+ for (mlir::Value handle : trampolineHandles)
+ fir::runtime::genTrampolineFree(builder, loc, handle);
+ });
+ }
+
+ /// Process a single operation for boxproc type rewriting.
+ void processOp(mlir::Operation *op, mlir::IRRewriter &rewriter,
+ BoxprocTypeRewriter &typeConverter) {
+ bool opIsValid{true};
+ typeConverter.setLocation(op->getLoc());
+ if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
+ mlir::Type ty{addr.getVal().getType()};
+ mlir::Type resTy{addr.getResult().getType()};
+ if (llvm::isa<mlir::FunctionType>(ty) ||
+ llvm::isa<fir::BoxProcType>(ty)) {
+ // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
+ // or function type to be `fir.convert` ops.
+ rewriter.setInsertionPoint(addr);
+ rewriter.replaceOpWithNewOp<ConvertOp>(
+ addr, typeConverter.convertType(addr.getType()), addr.getVal());
+ opIsValid = false;
+ } else if (typeConverter.needsConversion(resTy)) {
+ rewriter.startOpModification(op);
+ op->getResult(0).setType(typeConverter.convertType(resTy));
+ rewriter.finalizeOpModification(op);
+ }
+ } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
+ mlir::FunctionType ty{func.getFunctionType()};
+ if (typeConverter.needsConversion(ty)) {
+ rewriter.startOpModification(func);
+ auto toTy{
+ mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty))};
+ if (!func.empty())
+ for (auto e : llvm::enumerate(toTy.getInputs())) {
+ auto i{static_cast<unsigned>(e.index())};
+ auto &block{func.front()};
+ block.insertArgument(i, e.value(), func.getLoc());
+ block.getArgument(i + 1).replaceAllUsesWith(block.getArgument(i));
+ block.eraseArgument(i + 1);
}
- } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
- // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
- // as required.
- mlir::Type toTy = typeConverter.convertType(
- mlir::cast<BoxProcType>(embox.getType()).getEleTy());
- rewriter.setInsertionPoint(embox);
- if (embox.getHost()) {
- // Create the thunk.
- auto module = embox->getParentOfType<mlir::ModuleOp>();
- FirOpBuilder builder(rewriter, module);
- const auto triple{fir::getTargetTriple(module)};
- auto loc = embox.getLoc();
- mlir::Type i8Ty = builder.getI8Type();
- mlir::Type i8Ptr = builder.getRefType(i8Ty);
- // For PPC32 and PPC64, the thunk is populated by a call to
- // __trampoline_setup, which is defined in
- // compiler-rt/lib/builtins/trampoline_setup.c and requires the
- // thunk size greater than 32 bytes. For AArch64, RISCV and
x86_64,
- // the thunk setup doesn't go through __trampoline_setup and fits
in
- // 32 bytes.
- fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
- mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
- auto buffer = AllocaOp::create(builder, loc, buffTy);
- mlir::Value closure =
- builder.createConvert(loc, i8Ptr, embox.getHost());
- mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
- mlir::Value func =
- builder.createConvert(loc, i8Ptr, embox.getFunc());
- fir::CallOp::create(
- builder, loc, factory::getLlvmInitTrampoline(builder),
- llvm::ArrayRef<mlir::Value>{tramp, func, closure});
- auto adjustCall = fir::CallOp::create(
- builder, loc, factory::getLlvmAdjustTrampoline(builder),
- llvm::ArrayRef<mlir::Value>{tramp});
+ func.setType(toTy);
+ rewriter.finalizeOpModification(func);
+ }
+ } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
+ // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
+ // as required.
+ mlir::Type toTy{typeConverter.convertType(
+ mlir::cast<BoxProcType>(embox.getType()).getEleTy())};
+ rewriter.setInsertionPoint(embox);
+ if (embox.getHost()) {
+ auto module{embox->getParentOfType<mlir::ModuleOp>()};
+ auto loc{embox.getLoc()};
+
+ if (useSafeTrampoline) {
+ // Runtime trampoline pool path (W^X compliant).
+ // Insert Init/Adjust in the function's entry block so the
+ // handle dominates all func.return ops where TrampolineFree
+ // is emitted. This is necessary because fir.emboxproc may
+ // appear inside control flow branches. A cache avoids
+ // creating duplicate trampolines for the same internal
+ // procedure within a single host function.
+ mlir::Value funcVal{embox.getFunc()};
+ auto cacheIt{trampolineCallableMap.find(funcVal)};
+ if (cacheIt != trampolineCallableMap.end()) {
rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
- adjustCall.getResult(0));
- opIsValid = false;
+ cacheIt->second);
} else {
- // Just forward the function as a pointer.
- rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
- embox.getFunc());
- opIsValid = false;
- }
- } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
- auto ty = global.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.startOpModification(global);
- auto toTy = typeConverter.convertType(ty);
- global.setType(toTy);
- rewriter.finalizeOpModification(global);
- }
- } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
- auto ty = mem.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.setInsertionPoint(mem);
- auto toTy = typeConverter.convertType(unwrapRefType(ty));
- bool isPinned = mem.getPinned();
- llvm::StringRef uniqName =
- mem.getUniqName().value_or(llvm::StringRef());
- llvm::StringRef bindcName =
- mem.getBindcName().value_or(llvm::StringRef());
- rewriter.replaceOpWithNewOp<AllocaOp>(
- mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
- mem.getShape());
- opIsValid = false;
- }
- } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
- auto ty = mem.getType();
- if (typeConverter.needsConversion(ty)) {
- rewriter.setInsertionPoint(mem);
- auto toTy = typeConverter.convertType(unwrapRefType(ty));
- llvm::StringRef uniqName =
- mem.getUniqName().value_or(llvm::StringRef());
- llvm::StringRef bindcName =
- mem.getBindcName().value_or(llvm::StringRef());
- rewriter.replaceOpWithNewOp<AllocMemOp>(
- mem, toTy, uniqName, bindcName, mem.getTypeparams(),
- mem.getShape());
- opIsValid = false;
- }
- } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
- auto ty = coor.getType();
- mlir::Type baseTy = coor.getBaseType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(baseTy)) {
- rewriter.setInsertionPoint(coor);
- auto toTy = typeConverter.convertType(ty);
- auto toBaseTy = typeConverter.convertType(baseTy);
- rewriter.replaceOpWithNewOp<CoordinateOp>(
- coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy,
- coor.getFieldIndicesAttr());
- opIsValid = false;
- }
- } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
- auto ty = index.getType();
- mlir::Type onTy = index.getOnType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(onTy)) {
- rewriter.setInsertionPoint(index);
- auto toTy = typeConverter.convertType(ty);
- auto toOnTy = typeConverter.convertType(onTy);
- rewriter.replaceOpWithNewOp<FieldIndexOp>(
- index, toTy, index.getFieldId(), toOnTy,
index.getTypeparams());
- opIsValid = false;
- }
- } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
- auto ty = index.getType();
- mlir::Type onTy = index.getOnType();
- if (typeConverter.needsConversion(ty) ||
- typeConverter.needsConversion(onTy)) {
- rewriter.setInsertionPoint(index);
- auto toTy = typeConverter.convertType(ty);
- auto toOnTy = typeConverter.convertType(onTy);
- rewriter.replaceOpWithNewOp<LenParamIndexOp>(
- index, toTy, index.getFieldId(), toOnTy,
index.getTypeparams());
- opIsValid = false;
- }
- } else {
- rewriter.startOpModification(op);
- // Convert the operands if needed
- for (auto i : llvm::enumerate(op->getResultTypes()))
- if (typeConverter.needsConversion(i.value())) {
- auto toTy = typeConverter.convertType(i.value());
- op->getResult(i.index()).setType(toTy);
+ auto parentFunc{embox->getParentOfType<mlir::func::FuncOp>()};
+ auto &entryBlock{parentFunc.front()};
+
+ auto savedIP{rewriter.saveInsertionPoint()};
+
+ // Find the right insertion point in the entry block.
+ // Walk up from the emboxproc to find its top-level
+ // ancestor in the entry block. For an emboxproc directly
+ // in the entry block, this is the emboxproc itself.
+ // For one inside a structured op (fir.if, fir.do_loop),
+ // this is that structured op. For one inside an explicit
+ // branch target (cf.cond_br → ^bb1), we fall back to the
+ // entry block terminator.
+ mlir::Operation *entryAncestor{embox.getOperation()};
+ while (entryAncestor->getBlock() != &entryBlock) {
+ entryAncestor = entryAncestor->getParentOp();
+ if (!entryAncestor ||
+ mlir::isa<mlir::func::FuncOp>(entryAncestor))
+ break;
}
+ bool ancestorInEntry{
+ entryAncestor &&
+ !mlir::isa<mlir::func::FuncOp>(entryAncestor) &&
+ entryAncestor->getBlock() == &entryBlock};
- // Convert the type attributes if needed
- for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
- if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
- if (typeConverter.needsConversion(tyAttr.getValue())) {
- auto toTy = typeConverter.convertType(tyAttr.getValue());
- op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
+ // If the func value is not in the entry block (e.g.,
+ // address_of generated inside a structured fir.if),
+ // clone it into the entry block.
+ mlir::Value funcValInEntry{funcVal};
+ if (auto *funcDef{funcVal.getDefiningOp()}) {
+ if (funcDef->getBlock() != &entryBlock) {
+ if (ancestorInEntry)
+ rewriter.setInsertionPoint(entryAncestor);
+ else
+ rewriter.setInsertionPoint(entryBlock.getTerminator());
+ auto *cloned{rewriter.clone(*funcDef)};
+ funcValInEntry = cloned->getResult(0);
}
- rewriter.finalizeOpModification(op);
+ }
+
+ // Similarly clone the host value if not in entry block.
+ mlir::Value hostValInEntry{embox.getHost()};
+ if (auto *hostDef{embox.getHost().getDefiningOp()}) {
+ if (hostDef->getBlock() != &entryBlock) {
+ if (ancestorInEntry)
+ rewriter.setInsertionPoint(entryAncestor);
+ else
+ rewriter.setInsertionPoint(entryBlock.getTerminator());
+ auto *cloned{rewriter.clone(*hostDef)};
+ hostValInEntry = cloned->getResult(0);
+ }
----------------
jeanPerier wrote:
Thanks, this looks good to me. Thinking more about it, the invariant will be
broken if the MLIR inliner is run before this pass and a host procedure is
inlined into some other function body, potentially not in the entry block.
I do not like the idea of trying to clone the IR for the host link creation to
solve this. It may not be possible if it depends on some calls anyway. I think
the best solution to this is probably to add an operation in lowering to deal
with the trampoline cleanups when needed (with some mechanism to connect the
fir.embox_proc to it).
Since MLIR inlining is experimental in flang, I think it is OK to proceed
with your patch, but we should be aware that the current approach risks
hitting a compilation error when combined with the `-mllvm -inline-all`
developer option.
https://github.com/llvm/llvm-project/pull/183108
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits