Author: Mogball Date: 2022-06-29T17:03:14-07:00 New Revision: 81b151bcb52b5d8a6f505c333e63dce386250a08
URL: https://github.com/llvm/llvm-project/commit/81b151bcb52b5d8a6f505c333e63dce386250a08 DIFF: https://github.com/llvm/llvm-project/commit/81b151bcb52b5d8a6f505c333e63dce386250a08.diff LOG: test composition Added: Modified: mlir/test/lib/Analysis/TestDataFlowFramework.cpp Removed: ################################################################################ diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp index 2fd41b25b765..aaea25508637 100644 --- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp +++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/SparseDataFlowAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" @@ -182,8 +183,80 @@ void TestFooAnalysisPass::runOnOperation() { }); } +namespace { +struct AugmentSCP : public DataFlowAnalysis { + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override { + top->walk([&](Operation *op) { + if (op->getName().getStringRef() == "test.scp_region") + (void)visit(op); + }); + return success(); + } + + LogicalResult visit(ProgramPoint point) override { + auto *op = point.get<Operation *>(); + assert(op->getName().getStringRef() == "test.scp_region"); + + auto *rhs = getOrCreateFor<ConstantValueState>(op, op->getOperand(0)); + if (rhs->isUninitialized()) return success(); + + for (Region &region : op->getRegions()) { + for (Value value : region.getArguments()) { + assert(staticallyProvides(TypeID::get<ConstantValueState>(), value)); + update<ConstantValueState>( + value, [rhs](ConstantValueState *lhs) { return lhs->join(*rhs); }); + } + } + return success(); + } + + bool staticallyProvides(TypeID stateID, ProgramPoint point) const override { + if (stateID != TypeID::get<ConstantValueState>()) + return false; + + auto value = 
point.dyn_cast<Value>(); + if (!value || !value.isa<BlockArgument>() || + value.getParentBlock() != &value.getParentRegion()->front()) + return false; + + return value.getParentRegion()->getParentOp()->getName().getStringRef() == + "test.scp_region"; + } +}; + +struct AugmentSCPPass : public PassWrapper<AugmentSCPPass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AugmentSCPPass) + + StringRef getArgument() const override { return "test-augment-scp"; } + + void runOnOperation() override { + DataFlowSolver solver; + solver.load<DeadCodeAnalysis>(); + solver.load<SparseConstantPropagation>(); + solver.load<AugmentSCP>(); + if (failed(solver.initializeAndRun(getOperation()))) + return signalPassFailure(); + + getOperation()->walk([&](Operation *op) { + for (auto &result : llvm::enumerate(op->getResults())) { + auto *cv = solver.lookup<ConstantValueState>(result.value()); + if (!cv || cv->isUninitialized() || !cv->getValue().getConstantValue()) + continue; + llvm::errs() << "op " << op->getName() << " result #" << result.index() + << " -> " << cv->getValue().getConstantValue() << "\n"; + } + }); + } +}; +} // end anonymous namespace + namespace mlir { namespace test { -void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); } -} // namespace test -} // namespace mlir +void registerTestFooAnalysisPass() { + PassRegistration<TestFooAnalysisPass>(); + PassRegistration<AugmentSCPPass>(); +} +} // end namespace test +} // end namespace mlir _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits