================ @@ -0,0 +1,228 @@ +//===- CallConvLoweringPass.cpp - Lower CIR to ABI calling convention ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass walks every cir.func and cir.call in the module, computes a +// FunctionClassification for it (via either an ABI target or a pre-built +// classification injected as a function attribute), and dispatches to +// CIRABIRewriteContext to perform the actual IR rewriting. +// +// Two driver modes (mutually exclusive): +// +// target=test +// Use the MLIR test ABI target (mlir/lib/ABI/Targets/Test/) to classify +// each function. Predictable rules that approximate x86_64 SysV. Real +// targets (x86_64, AArch64) will be added once the LLVM ABI library +// ships them. +// +// classification-attr=<name> +// Read a DictionaryAttr named <name> from each cir.func and parse it via +// mlir::abi::test::parseClassificationAttr. Used by tests to inject any +// classification (including shapes the test target itself does not +// produce) without depending on a real ABI target. +// +// The pass requires a `dlti.dl_spec` attribute on the module so the +// classifier can query type sizes and alignments. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "TargetLowering/CIRABIRewriteContext.h" + +#include "mlir/ABI/ABIRewriteContext.h" +#include "mlir/ABI/Targets/Test/TestTarget.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/Passes.h" + +using namespace mlir; +using namespace mlir::abi; +using namespace cir; + +namespace mlir { +#define GEN_PASS_DEF_CALLCONVLOWERING +#include "clang/CIR/Dialect/Passes.h.inc" +} // namespace mlir + +namespace { + +bool needsRewrite(const FunctionClassification &fc) { + if ((fc.returnInfo.kind != ArgKind::Direct) || fc.returnInfo.coercedType) + return true; + for (const ArgClassification &ac : fc.argInfos) + if ((ac.kind != ArgKind::Direct) || ac.coercedType) + return true; + return false; +} + +struct CallConvLoweringPass + : public impl::CallConvLoweringBase<CallConvLoweringPass> { + using CallConvLoweringBase::CallConvLoweringBase; + void runOnOperation() override; +}; + +/// Classify \p func using whichever driver mode is configured. Returns +/// std::nullopt and emits an error on the function if classification fails +/// (e.g. injection-driver mode but the function is missing the attribute, +/// or the attribute is malformed). +std::optional<FunctionClassification> +classifyFunction(cir::FuncOp func, const DataLayout &dl, StringRef target, + StringRef classificationAttrName) { + ArrayRef<Type> argTypes = func.getFunctionType().getInputs(); + Type returnType = func.getFunctionType().getReturnType(); + + if (!classificationAttrName.empty()) { + auto attr = func->getAttrOfType<DictionaryAttr>(classificationAttrName); + if (!attr) { + func.emitOpError() + << "missing classification attribute '" << classificationAttrName + << "' (CallConvLowering driver mode 'classification-attr')"; + return std::nullopt; + } + return mlir::abi::test::parseClassificationAttr( + attr, [&]() { return func.emitOpError(); }); + } + + if (target == "test") + return mlir::abi::test::classify(argTypes, returnType, dl); + + func.emitOpError() << "unknown target '" << target << "' (supported: test)"; + return std::nullopt; +} + +/// Find the cir.func declaration matching a direct cir.call / cir.try_call +/// callee, if any. Returns nullptr if the callee is indirect or the symbol +/// cannot be resolved. Takes a SymbolTable instead of a ModuleOp so the +/// symbol lookup is amortized across all the call sites the driver walks +/// (ModuleOp::lookupSymbol is linear per call). +cir::FuncOp lookupCallee(Operation *callOp, SymbolTable &symbolTable) { + FlatSymbolRefAttr callee; + if (auto call = dyn_cast<cir::CallOp>(callOp)) + callee = call.getCalleeAttr(); + else if (auto tryCall = dyn_cast<cir::TryCallOp>(callOp)) + callee = tryCall.getCalleeAttr(); + else + return nullptr; + if (!callee) + return nullptr; + return symbolTable.lookup<cir::FuncOp>(callee.getValue()); +} + +void CallConvLoweringPass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + if (target.empty() == classificationAttr.empty()) { + moduleOp.emitOpError() << "CallConvLowering requires exactly one of " + "'target' or 'classification-attr' pass options"; + signalPassFailure(); + return; + } + + if (!moduleOp->hasAttr(DLTIDialect::kDataLayoutAttrName)) { + moduleOp.emitOpError() + << "CallConvLowering requires a DataLayout (dlti.dl_spec attribute " + "on the module)"; + signalPassFailure(); + return; + } + + DataLayout dl(moduleOp); + CIRABIRewriteContext rewriteCtx(moduleOp); + SymbolTable symbolTable(moduleOp); + + // Phase 1: classify every cir.func. No IR mutation happens here, so + // running this as a single up-front walk lets later phases consult any + // function's classification regardless of visitation order. + llvm::MapVector<cir::FuncOp, FunctionClassification> classifications; + bool anyFailed = false; + moduleOp.walk([&](cir::FuncOp f) { + auto fc = classifyFunction(f, dl, target, classificationAttr); + if (!fc) { + anyFailed = true; + return; + } + classifications.insert({f, std::move(*fc)}); + }); + if (anyFailed) { + signalPassFailure(); + return; + } + + // Phase 2: build a callee -> callers index. A single module walk gives + // us every direct cir.call / cir.try_call to each cir.func; we use this + // in phase 3 to rewrite a function and all of its call sites together. + // Indirect or unresolved callees are skipped here (rewriteCallSite + // rejects indirect calls; see phase 4). + llvm::DenseMap<cir::FuncOp, SmallVector<Operation *>> callers; + moduleOp.walk([&](Operation *op) { + if (!isa<cir::CallOp, cir::TryCallOp>(op)) + return; + if (cir::FuncOp callee = lookupCallee(op, symbolTable)) + callers[callee].push_back(op); + }); + + // Phase 3: rewrite each function together with every direct call to + // it. By the time we move on to function F+1, F's signature and every + // direct call to F have already been brought into alignment, and + // F+1..FN are still in their original (mutually consistent) form, so + // the IR is verifier-clean at every outer-iteration boundary. + // + // There is still a brief inner window where F's signature has been + // rewritten but its callers have not yet caught up -- the MLIR + // rewriter API gives us no way to mutate both sides of a call + // atomically. No verifier runs inside the pass, and at pass exit the + // module is verifier-clean. Fusing the inner loop here is what keeps + // the invalid window per-function rather than module-wide. + OpBuilder rewriter(ctx); + for (auto &kv : classifications) { + cir::FuncOp func = kv.first; + const FunctionClassification &fc = kv.second; + if (failed(rewriteCtx.rewriteFunctionDefinition(func, fc, rewriter))) { + signalPassFailure(); + return; + } + for (Operation *callOp : callers.lookup(func)) + if (failed(rewriteCtx.rewriteCallSite(callOp, fc, rewriter))) { + signalPassFailure(); + return; + } + } + + // Phase 4: reject indirect calls when the module contains any ABI rewrite + // that would need call-site lowering. We cannot strip or coerce operands + // without a resolved callee symbol. + const FunctionClassification *rewriteFc = nullptr; + for (auto &kv : classifications) ---------------- andykaylor wrote:
This needs braces on the for-loop to be consistent with the LLVM coding standards since the statement in the body uses braces. https://github.com/llvm/llvm-project/pull/195737 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
