https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/175815
RFC: https://discourse.llvm.org/t/rfc-simplify-regionbranchopinterface-separate-successor-inputs-from-region-successor/89420/7 Depends on #174945. >From 4976ae49712431d4dfe89ff590c30a35a411d0f3 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Tue, 13 Jan 2026 19:13:34 +0000 Subject: [PATCH] [mlir][Interfaces] Split successor inputs from region successor --- .../Analysis/DataFlow/IntegerRangeAnalysis.h | 1 + .../mlir/Analysis/DataFlow/SparseAnalysis.h | 12 ++-- .../mlir/Dialect/Affine/IR/AffineOps.td | 5 +- .../include/mlir/Dialect/Async/IR/AsyncOps.td | 1 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 +- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 3 +- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 3 +- .../mlir/Dialect/OpenACC/OpenACCOps.td | 21 ++++-- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 12 ++-- .../include/mlir/Dialect/Shape/IR/ShapeOps.td | 2 +- .../SparseTensor/IR/SparseTensorOps.td | 2 +- .../mlir/Dialect/Transform/IR/TransformOps.td | 8 ++- .../TuneExtension/TuneExtensionOps.td | 3 +- .../mlir/Interfaces/ControlFlowInterfaces.h | 23 ++---- .../mlir/Interfaces/ControlFlowInterfaces.td | 10 +++ .../AliasAnalysis/LocalAliasAnalysis.cpp | 8 +-- .../Analysis/DataFlow/DeadCodeAnalysis.cpp | 3 +- .../DataFlow/IntegerRangeAnalysis.cpp | 6 +- mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 25 +++---- mlir/lib/Analysis/SliceWalk.cpp | 2 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 36 +++++++--- mlir/lib/Dialect/Async/IR/Async.cpp | 11 ++- .../OwnershipBasedBufferDeallocation.cpp | 8 +-- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 17 +++-- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 7 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 ++- mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 48 ++++++++++++- mlir/lib/Dialect/SCF/IR/SCF.cpp | 70 +++++++++++++------ mlir/lib/Dialect/Shape/IR/Shape.cpp | 8 ++- .../SparseTensor/IR/SparseTensorDialect.cpp | 10 ++- .../lib/Dialect/Transform/IR/TransformOps.cpp | 40 ++++++++--- .../TuneExtension/TuneExtensionOps.cpp 
| 14 ++-- mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 6 +- mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 60 ++++++++++++---- mlir/test/lib/Dialect/Test/TestOps.td | 13 ++-- .../Interfaces/ControlFlowInterfacesTest.cpp | 18 ++++- 36 files changed, 371 insertions(+), 156 deletions(-) diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index 4975cedb282e4..e549a56a6f960 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -68,6 +68,7 @@ class IntegerRangeAnalysis /// known bounds. void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, + ValueRange successorInputs, ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) override; }; diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 1bb42a246b701..02f699de06f99 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -215,7 +215,8 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { /// of loops). virtual void visitNonControlFlowArgumentsImpl( Operation *op, const RegionSuccessor &successor, - ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0; + ValueRange successorInputs, ArrayRef<AbstractSparseLattice *> argLattices, + unsigned firstIndex) = 0; /// Get the lattice element of a value. virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; @@ -328,11 +329,12 @@ class SparseForwardDataFlowAnalysis /// index of the first element of `argLattices` that is set by control-flow. 
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, + ValueRange successorInputs, ArrayRef<StateT *> argLattices, unsigned firstIndex) { setAllToEntryStates(argLattices.take_front(firstIndex)); - setAllToEntryStates(argLattices.drop_front( - firstIndex + successor.getSuccessorInputs().size())); + setAllToEntryStates( + argLattices.drop_front(firstIndex + successorInputs.size())); } protected: @@ -383,10 +385,10 @@ class SparseForwardDataFlowAnalysis } void visitNonControlFlowArgumentsImpl( Operation *op, const RegionSuccessor &successor, - ArrayRef<AbstractSparseLattice *> argLattices, + ValueRange successorInputs, ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) override { visitNonControlFlowArguments( - op, successor, + op, successor, successorInputs, {reinterpret_cast<StateT *const *>(argLattices.begin()), argLattices.size()}, firstIndex); diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index bd14f6ff4c5aa..482987ebab27d 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -128,7 +128,7 @@ def AffineForOp : Affine_Op<"for", "getLoopUpperBounds", "getYieldedValuesMutable", "replaceWithAdditionalYields"]>, DeclareOpInterfaceMethods<RegionBranchOpInterface, - ["getEntrySuccessorOperands"]>]> { + ["getEntrySuccessorOperands", "getSuccessorInputs"]>]> { let summary = "for operation"; let description = [{ Syntax: @@ -340,7 +340,8 @@ def AffineForOp : Affine_Op<"for", def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator, RecursivelySpeculatable, RecursiveMemoryEffects, NoRegionArguments, - DeclareOpInterfaceMethods<RegionBranchOpInterface> + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]> ]> { let summary = "if-then-else operation"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td 
b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index b52f13697f0dc..2cebeac767f29 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -36,6 +36,7 @@ def Async_ExecuteOp : Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">, DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands", + "getSuccessorInputs", "areTypesCompatible"]>, AttrSizedOperandSegments, AutomaticAllocationScope, diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index c1820904f2665..caed3233f62e9 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1470,7 +1470,8 @@ def EmitC_YieldOp : EmitC_Op<"yield", def EmitC_IfOp : EmitC_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [ "getNumRegionInvocations", "getRegionInvocationBounds", - "getEntrySuccessorRegions"]>, OpAsmOpInterface, SingleBlock, + "getEntrySuccessorRegions", "getSuccessorInputs"]>, + OpAsmOpInterface, SingleBlock, SingleBlockImplicitTerminator<"emitc::YieldOp">, RecursiveMemoryEffects, NoRegionArguments]> { let summary = "If-then-else operation"; diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index e8c23200547d6..b00f69b79d12c 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -3080,7 +3080,8 @@ def GPU_SetCsrPointersOp : GPU_Op<"set_csr_pointers", [GPU_AsyncOpInterface]> { } def GPU_WarpExecuteOnLane0Op : GPU_Op<"warp_execute_on_lane_0", - [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["areTypesCompatible"]>, + [DeclareOpInterfaceMethods<RegionBranchOpInterface, [ + "areTypesCompatible", "getSuccessorInputs"]>, SingleBlockImplicitTerminator<"gpu::YieldOp">, RecursiveMemoryEffects]> { let summary = "Executes operations in the associated region on thread #0 of a" diff --git 
a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 45122788bd2d4..bd96bace7994f 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -420,7 +420,8 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope", [AutomaticAllocationScope, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, [ + "getSuccessorInputs"]>, SingleBlockImplicitTerminator<"AllocaScopeReturnOp">, RecursiveMemoryEffects, NoRegionArguments]> { diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 644d1f8e9e649..11ee5f0b1088f 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -1684,7 +1684,8 @@ def OpenACC_ParallelOp [AttrSizedOperandSegments, AutomaticAllocationScope, RecursiveMemoryEffects, DeclareOpInterfaceMethods<ComputeRegionOpInterface>, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]>, OffloadRegionOpInterface, MemoryEffects<[MemWrite<OpenACC_ConstructResource>, MemRead<OpenACC_CurrentDeviceIdResource>]>]> { @@ -1885,7 +1886,8 @@ def OpenACC_SerialOp [AttrSizedOperandSegments, AutomaticAllocationScope, RecursiveMemoryEffects, DeclareOpInterfaceMethods<ComputeRegionOpInterface>, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]>, OffloadRegionOpInterface, MemoryEffects<[MemWrite<OpenACC_ConstructResource>, MemRead<OpenACC_CurrentDeviceIdResource>]>]> { @@ -2026,7 +2028,8 @@ def OpenACC_KernelsOp [AttrSizedOperandSegments, AutomaticAllocationScope, RecursiveMemoryEffects, DeclareOpInterfaceMethods<ComputeRegionOpInterface>, - 
DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]>, OffloadRegionOpInterface, MemoryEffects<[MemWrite<OpenACC_ConstructResource>, MemRead<OpenACC_CurrentDeviceIdResource>]>]> { @@ -2208,7 +2211,8 @@ def OpenACC_KernelEnvironmentOp : OpenACC_Op<"kernel_environment", [AttrSizedOperandSegments, RecursiveMemoryEffects, SingleBlock, NoTerminator, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]>, MemoryEffects<[MemWrite<OpenACC_ConstructResource>, MemRead<OpenACC_CurrentDeviceIdResource>]>]> { let summary = "Decomposition of compute constructs to capture data mapping " @@ -2261,7 +2265,8 @@ def OpenACC_KernelEnvironmentOp def OpenACC_DataOp : OpenACC_Op< "data", [AttrSizedOperandSegments, RecursiveMemoryEffects, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]>, MemoryEffects<[MemWrite<OpenACC_ConstructResource>, MemRead<OpenACC_CurrentDeviceIdResource>]>]> { let summary = "data construct"; @@ -2537,7 +2542,8 @@ def OpenACC_ExitDataOp : OpenACC_Op<"exit_data", def OpenACC_HostDataOp : OpenACC_Op<"host_data", [AttrSizedOperandSegments, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]>, MemoryEffects<[MemWrite<OpenACC_ConstructResource>, MemRead<OpenACC_CurrentDeviceIdResource>]>]> { let summary = "host_data construct"; @@ -2583,7 +2589,8 @@ def OpenACC_LoopOp RecursiveMemoryEffects, DeclareOpInterfaceMethods<ComputeRegionOpInterface>, DeclareOpInterfaceMethods<LoopLikeOpInterface>, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getSuccessorInputs"]>, MemoryEffects<[MemWrite<OpenACC_ConstructResource>]>]> { let summary = "loop construct"; diff --git 
a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 48a377491df02..a08cf3c95e6ce 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -77,7 +77,8 @@ def ConditionOp : SCF_Op<"condition", [ //===----------------------------------------------------------------------===// def ExecuteRegionOp : SCF_Op<"execute_region", [ - DeclareOpInterfaceMethods<RegionBranchOpInterface>, RecursiveMemoryEffects]> { + DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>, + RecursiveMemoryEffects]> { let summary = "operation that executes its region exactly once"; let description = [{ The `scf.execute_region` operation is used to allow multiple blocks within SCF @@ -159,7 +160,7 @@ def ForOp : SCF_Op<"for", AllTypesMatch<["lowerBound", "upperBound", "step"]>, ConditionallySpeculatable, DeclareOpInterfaceMethods<RegionBranchOpInterface, - ["getEntrySuccessorOperands"]>, + ["getEntrySuccessorOperands", "getSuccessorInputs"]>, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects]> { let summary = "for operation"; @@ -699,7 +700,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [ "getNumRegionInvocations", "getRegionInvocationBounds", - "getEntrySuccessorRegions"]>, + "getEntrySuccessorRegions", "getSuccessorInputs"]>, InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> { let summary = "if-then-else operation"; @@ -982,7 +983,7 @@ def ReduceReturnOp : def WhileOp : SCF_Op<"while", [DeclareOpInterfaceMethods<RegionBranchOpInterface, - ["getEntrySuccessorOperands"]>, + ["getEntrySuccessorOperands", "getSuccessorInputs"]>, DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getRegionIterArgs", "getYieldedValuesMutable"]>, RecursiveMemoryEffects, SingleBlock]> { @@ -1136,7 +1137,8 @@ def 
IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::YieldOp">, DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getRegionInvocationBounds", - "getEntrySuccessorRegions"]>]> { + "getEntrySuccessorRegions", + "getSuccessorInputs"]>]> { let summary = "switch-case operation on an index argument"; let description = [{ The `scf.index_switch` is a control-flow operation that branches to one of diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index cbf1223298f90..fc9f498fdb805 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -818,7 +818,7 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, Pure]> { def Shape_AssumingOp : Shape_Op<"assuming", [ SingleBlockImplicitTerminator<"AssumingYieldOp">, - DeclareOpInterfaceMethods<RegionBranchOpInterface>, + DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>, RecursiveMemoryEffects]> { let summary = "Execute the region"; let description = [{ diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index a61d90a0c39b1..d4901645c51d1 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1562,7 +1562,7 @@ def IterateOp : SparseTensor_Op<"iterate", ["getInitsMutable", "getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>, DeclareOpInterfaceMethods<RegionBranchOpInterface, - ["getEntrySuccessorOperands"]>, + ["getEntrySuccessorOperands", "getSuccessorInputs"]>, SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> { let summary = "Iterates over a sparse iteration space"; diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index ed69287410509..d0de4aaed310c 100644 --- 
a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -26,7 +26,8 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" def AlternativesOp : TransformDialectOp<"alternatives", [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands", - "getRegionInvocationBounds"]>, + "getRegionInvocationBounds", + "getSuccessorInputs"]>, DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, IsolatedFromAbove, PossibleTopLevelTransformOpTrait, @@ -624,7 +625,7 @@ def ForeachOp : TransformDialectOp<"foreach", [DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<RegionBranchOpInterface, [ - "getEntrySuccessorOperands"]>, + "getEntrySuccessorOperands", "getSuccessorInputs"]>, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp"> ]> { let summary = "Executes the body for each element of the payload"; @@ -1238,7 +1239,8 @@ def SelectOp : TransformDialectOp<"select", def SequenceOp : TransformDialectOp<"sequence", [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands", - "getRegionInvocationBounds"]>, + "getRegionInvocationBounds", + "getSuccessorInputs"]>, MatchOpInterface, DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td index 4079848fd203a..eeb32486433e1 100644 --- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td @@ -64,7 +64,8 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [ def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [ DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands", - 
"getRegionInvocationBounds"]>, + "getRegionInvocationBounds", + "getSuccessorInputs"]>, DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">, diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 1e21348b4ea39..529afc7e43d27 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -207,18 +207,13 @@ class RegionSuccessor { public: /// Initialize a successor that branches to another region of the parent /// operation. - /// TODO: the default value for the regionInputs is somehow broken. - /// A region successor should have its input correctly set. - RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {}) - : successor(region), inputs(regionInputs) { + RegionSuccessor(Region *region) : successor(region) { assert(region && "Region must not be null"); } /// Initialize a successor that branches back to/out of the parent operation. /// The target must be one of the recursive parent operations. - static RegionSuccessor parent(Operation::result_range results) { - return RegionSuccessor(results); - } + static RegionSuccessor parent() { return RegionSuccessor(); } /// Return the given region successor. Returns nullptr if the successor is the /// parent operation. @@ -227,25 +222,21 @@ class RegionSuccessor { /// Return true if the successor is the parent operation. bool isParent() const { return successor == nullptr; } - /// Return the inputs to the successor that are remapped by the exit values of - /// the current region. 
- ValueRange getSuccessorInputs() const { return inputs; } - bool operator==(RegionSuccessor rhs) const { - return successor == rhs.successor && inputs == rhs.inputs; + return successor == rhs.successor; } + bool operator==(const Region *region) const { return successor == region; } + friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) { return !(lhs == rhs); } private: /// Private constructor to encourage the use of `RegionSuccessor::parent`. - RegionSuccessor(Operation::result_range results) - : successor(nullptr), inputs(ValueRange(results)) {} + RegionSuccessor() : successor(nullptr) {} Region *successor = nullptr; - ValueRange inputs; }; /// This class represents a point being branched from in the methods of the @@ -310,7 +301,7 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, if (successor.isParent()) return os << "<to parent>"; return os << "<to region #" << successor.getSuccessor()->getRegionNumber() - << " with " << successor.getSuccessorInputs().size() << " inputs>"; + << ">"; } /// This class represents upper and lower bounds on the number of times a region diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index d1451552d7b0f..c9e62ac1a958c 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -262,6 +262,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { regions); } }]>, + InterfaceMethod<[{ + Return all successor inputs for the given region successor. + }], + "::mlir::ValueRange", "getSuccessorInputs", + (ins "::mlir::RegionSuccessor":$successor), + [{}], + /*defaultImplementation=*/[{ + // Default implementation: No successor inputs. + return ::mlir::ValueRange(); + }]>, InterfaceMethod<[{ Returns the potential branching points (predecessors) for a given region successor. 
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index be53a4e56f37a..b698756dd75e9 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -57,7 +57,7 @@ static void collectUnderlyingAddressValues2( LDBG() << " inputValue: " << inputValue; LDBG() << " inputIndex: " << inputIndex; LDBG() << " maxDepth: " << maxDepth; - ValueRange inputs = initialSuccessor.getSuccessorInputs(); + ValueRange inputs = branch.getSuccessorInputs(initialSuccessor); if (inputs.empty()) { LDBG() << " input is empty, enqueue value"; output.push_back(inputValue); @@ -108,9 +108,9 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, // Check to see if we can reason about the control flow of this op. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { LDBG() << " Processing region branch operation"; - return collectUnderlyingAddressValues2( - branch, RegionSuccessor::parent(op->getResults()), result, - result.getResultNumber(), maxDepth, visited, output); + return collectUnderlyingAddressValues2(branch, RegionSuccessor::parent(), + result, result.getResultNumber(), + maxDepth, visited, output); } LDBG() << " Adding result to output: " << result; diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 7648d8ab4b532..3ce0f94e0c6da 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -527,7 +527,8 @@ void DeadCodeAnalysis::visitRegionBranchEdges( auto *predecessors = getOrCreate<PredecessorState>(point); propagateIfChanged( predecessors, - predecessors->join(predecessorOp, successor.getSuccessorInputs())); + predecessors->join(predecessorOp, + regionBranchOp.getSuccessorInputs(successor))); LDBG() << "Added region branch as predecessor for successor: " << *point; } } diff --git 
a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index a93e605445465..012d8384d3098 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -138,7 +138,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation( } void IntegerRangeAnalysis::visitNonControlFlowArguments( - Operation *op, const RegionSuccessor &successor, + Operation *op, const RegionSuccessor &successor, ValueRange successorInputs, ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { LDBG() << "Inferring ranges for " @@ -208,7 +208,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( loop.getLoopInductionVars(); if (!maybeIvs) { return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( - op, successor, argLattices, firstIndex); + op, successor, successorInputs, argLattices, firstIndex); } // This shouldn't be returning nullopt if there are indunction variables. SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds(); @@ -246,5 +246,5 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( } return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( - op, successor, argLattices, firstIndex); + op, successor, successorInputs, argLattices, firstIndex); } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index bc236aa13db04..f86bb55df3ac5 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -134,8 +134,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // The results of a region branch operation are determined by control-flow. 
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { visitRegionSuccessors(getProgramPointAfter(branch), branch, - RegionSuccessor::parent(branch->getResults()), - resultLattices); + RegionSuccessor::parent(), resultLattices); return success(); } @@ -187,9 +186,9 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { } // Otherwise, we can't reason about the data-flow. - return visitNonControlFlowArgumentsImpl(block->getParentOp(), - RegionSuccessor(block->getParent()), - argLattices, /*firstIndex=*/0); + return visitNonControlFlowArgumentsImpl( + block->getParentOp(), RegionSuccessor(block->getParent()), ValueRange(), + argLattices, /*firstIndex=*/0); } // Iterate over the predecessors of the non-entry block. @@ -316,19 +315,17 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( if (!inputs.empty()) firstIndex = cast<OpResult>(inputs.front()).getResultNumber(); visitNonControlFlowArgumentsImpl( - branch, - RegionSuccessor::parent( - branch->getResults().slice(firstIndex, inputs.size())), - lattices, firstIndex); + branch, RegionSuccessor::parent(), + branch->getResults().slice(firstIndex, inputs.size()), lattices, + firstIndex); } else { if (!inputs.empty()) firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber(); Region *region = point->getBlock()->getParent(); visitNonControlFlowArgumentsImpl( - branch, - RegionSuccessor(region, region->getArguments().slice( - firstIndex, inputs.size())), - lattices, firstIndex); + branch, RegionSuccessor(region), + region->getArguments().slice(firstIndex, inputs.size()), lattices, + firstIndex); } } @@ -620,7 +617,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( SmallVector<BlockArgument> noControlFlowArguments; MutableArrayRef<BlockArgument> arguments = successor.getSuccessor()->getArguments(); - ValueRange inputs = successor.getSuccessorInputs(); + ValueRange inputs = branch.getSuccessorInputs(successor); for (BlockArgument argument : arguments) { // 
Visit blockArgument of RegionBranchOp which isn't "control // flow block arguments". For example, the IV of a loop. diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp index 9baf856186979..5c5a68ef11b36 100644 --- a/mlir/lib/Analysis/SliceWalk.cpp +++ b/mlir/lib/Analysis/SliceWalk.cpp @@ -68,7 +68,7 @@ mlir::getControlFlowPredecessors(Value value) { if (!regionOp) return std::nullopt; // Add the control flow predecessor operands to the work list. - RegionSuccessor region = RegionSuccessor::parent(regionOp->getResults()); + RegionSuccessor region = RegionSuccessor::parent(); SmallVector<Value> predecessorOperands; // TODO (#175168): This assumes that there are no non-successor-inputs // in front of the op result. diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index df1b93e367fc6..84813810cfa57 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2741,18 +2741,18 @@ void AffineForOp::getSuccessorRegions( // From the loop body, if the trip count is one, we can only branch back // to the parent. if (tripCount == 1) { - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } if (tripCount == 0) return; } else { if (tripCount.value() > 0) { - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + regions.push_back(RegionSuccessor(&getRegion())); return; } if (tripCount.value() == 0) { - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } } @@ -2760,8 +2760,14 @@ void AffineForOp::getSuccessorRegions( // In all other cases, the loop may branch back to itself or the parent // operation. 
- regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange AffineForOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return getRegionIterArgs(); } AffineBound AffineForOp::getLowerBound() { @@ -3146,21 +3152,29 @@ void AffineIfOp::getSuccessorRegions( // `else` region is valid. if (point.isParent()) { regions.reserve(2); - regions.push_back( - RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); + regions.push_back(RegionSuccessor(&getThenRegion())); // If the "else" region is empty, branch bach into parent. if (getElseRegion().empty()) { - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); } else { - regions.push_back( - RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); + regions.push_back(RegionSuccessor(&getElseRegion())); } return; } // If the predecessor is the `else`/`then` region, then branching into parent // op is valid. 
- regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange AffineIfOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + if (successor == &getThenRegion()) + return getThenRegion().getArguments(); + if (successor == &getElseRegion()) + return getElseRegion().getArguments(); + llvm_unreachable("invalid region successor"); } LogicalResult AffineIfOp::verify() { diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 11fd87ed925d8..233c762b9de3c 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -55,13 +55,18 @@ void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, if (!point.isParent() && point.getTerminatorPredecessorOrNull()->getParentRegion() == &getBodyRegion()) { - regions.push_back(RegionSuccessor::parent(getBodyResults())); + regions.push_back(RegionSuccessor::parent()); return; } // Otherwise the successor is the body region. 
- regions.push_back( - RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); + regions.push_back(RegionSuccessor(&getBodyRegion())); +} + +ValueRange ExecuteOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getBodyResults(); + return getBodyRegion().getArguments(); } void ExecuteOp::build(OpBuilder &builder, OperationState &result, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 9e8746cb8ea35..6081e515d4e3a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -563,9 +563,7 @@ BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) { SmallVector<TypeRange> returnOperandTypes(llvm::map_range( op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(), [&](RegionBranchTerminatorOpInterface branchOp) { - return branchOp - .getSuccessorOperands( - RegionSuccessor::parent(op.getOperation()->getResults())) + return branchOp.getSuccessorOperands(RegionSuccessor::parent()) .getTypes(); })); if (!llvm::all_equal(returnOperandTypes)) @@ -945,8 +943,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // about, but we would need to check how many successors there are and under // which condition they are taken, etc. 
- MutableOperandRange operands = op.getMutableSuccessorOperands( - RegionSuccessor::parent(op.getOperation()->getResults())); + MutableOperandRange operands = + op.getMutableSuccessorOperands(RegionSuccessor::parent()); SmallVector<Value> updatedOwnerships; auto result = deallocation_impl::insertDeallocOpForReturnLike( diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index ca29ff833535d..e6f8f58be61ca 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -878,7 +878,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(RegionSuccessor::parent(getOperation()->getResults())); + regions.push_back(RegionSuccessor::parent()); return; } @@ -887,11 +887,21 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. 
Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor::parent(getOperation()->getResults())); + regions.push_back(RegionSuccessor::parent()); else regions.push_back(RegionSuccessor(elseRegion)); } +ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + if (successor == &getThenRegion()) + return ValueRange(); + if (successor == &getElseRegion()) + return ValueRange(); + llvm_unreachable("invalid region successor"); +} + void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) { FoldAdaptor adaptor(operands, *this); @@ -904,8 +914,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back( - RegionSuccessor::parent(getOperation()->getResults())); + regions.emplace_back(RegionSuccessor::parent()); } } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index ed4be4dad6704..345a7ed54b578 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2398,7 +2398,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, void WarpExecuteOnLane0Op::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } @@ -2406,6 +2406,11 @@ void WarpExecuteOnLane0Op::getSuccessorRegions( regions.push_back(RegionSuccessor(&getWarpRegion())); } +ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return ValueRange(); +} void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value laneId, int64_t warpSize) { diff --git 
a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 9a604d1f109de..c5e10f78286f4 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -405,13 +405,19 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } regions.push_back(RegionSuccessor(&getBodyRegion())); } +ValueRange AllocaScopeOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return ValueRange(); +} + /// Given an operation, return whether this op is guaranteed to /// allocate an AutomaticAllocationScopeResource static bool isGuaranteedAutomaticAllocation(Operation *op) { diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 50b4d0563faef..dc0b4c2db244b 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -411,7 +411,7 @@ getSingleRegionOpSuccessorRegions(Operation *op, Region &region, return; } - regions.push_back(RegionSuccessor::parent(op->getResults())); + regions.push_back(RegionSuccessor::parent()); } void KernelsOp::getSuccessorRegions(RegionBranchPoint point, @@ -420,36 +420,72 @@ void KernelsOp::getSuccessorRegions(RegionBranchPoint point, regions); } +ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void ParallelOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, regions); } +ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) { + if 
(successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void SerialOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, regions); } +ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void KernelEnvironmentOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, regions); } +ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void DataOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, regions); } +ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void HostDataOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point, regions); } +ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void LoopOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { // Unstructured loops: the body may contain arbitrary CFG and early exits. 
@@ -460,13 +496,19 @@ void LoopOp::getSuccessorRegions(RegionBranchPoint point, regions.push_back(RegionSuccessor(&getRegion())); return; } - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } // Structured loops: model a loop-shaped region graph similar to scf.for. regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 2075cad593abf..d3ec9f3dcf85f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -307,7 +307,13 @@ void ExecuteRegionOp::getSuccessorRegions( } // Otherwise, the region branches back to the parent operation. - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); } //===----------------------------------------------------------------------===// @@ -334,10 +340,9 @@ void ConditionOp::getSuccessorRegions( // depending on whether the condition is true or not. 
auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); if (!boolAttr || boolAttr.getValue()) - regions.emplace_back(&whileOp.getAfter(), - whileOp.getAfter().getArguments()); + regions.emplace_back(&whileOp.getAfter()); if (!boolAttr || !boolAttr.getValue()) - regions.push_back(RegionSuccessor::parent(whileOp.getResults())); + regions.push_back(RegionSuccessor::parent()); } //===----------------------------------------------------------------------===// @@ -703,8 +708,14 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for loop not to enter the // body. - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return getRegionIterArgs(); } SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -1827,14 +1838,12 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point, regions.push_back(RegionSuccessor(&getRegion())); // However, when there are 0 threads, the control flow may branch back to // the parent immediately. - regions.push_back(RegionSuccessor::parent( - ResultRange{getResults().end(), getResults().end()})); + regions.push_back(RegionSuccessor::parent()); } else { // In accordance with the semantics of forall, its body is executed in // parallel by multiple threads. We should not expect to branch back into // the forall body after the region's execution is complete. 
- regions.push_back(RegionSuccessor::parent( - ResultRange{getResults().end(), getResults().end()})); + regions.push_back(RegionSuccessor::parent()); } } @@ -2116,7 +2125,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // The `then` and the `else` region branch back to the parent operation or one // of the recursive parent operations (early exit case). if (!point.isParent()) { - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } @@ -2125,11 +2134,17 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); else regions.push_back(RegionSuccessor(elseRegion)); } +ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) { FoldAdaptor adaptor(operands, *this); @@ -2142,7 +2157,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(RegionSuccessor::parent(getResults())); + regions.emplace_back(RegionSuccessor::parent()); } } @@ -3157,8 +3172,7 @@ void ParallelOp::getSuccessorRegions( // back into the operation itself. It is possible for loop not to enter the // body. 
regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor::parent( - ResultRange{getResults().end(), getResults().end()})); + regions.push_back(RegionSuccessor::parent()); } //===----------------------------------------------------------------------===// @@ -3302,7 +3316,7 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { // The parent op always branches to the condition region. if (point.isParent()) { - regions.emplace_back(&getBefore(), getBefore().getArguments()); + regions.emplace_back(&getBefore()); return; } @@ -3313,12 +3327,22 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, // The body region always branches back to the condition region. if (point.getTerminatorPredecessorOrNull()->getParentRegion() == &getAfter()) { - regions.emplace_back(&getBefore(), getBefore().getArguments()); + regions.emplace_back(&getBefore()); return; } - regions.push_back(RegionSuccessor::parent(getResults())); - regions.emplace_back(&getAfter(), getAfter().getArguments()); + regions.push_back(RegionSuccessor::parent()); + regions.emplace_back(&getAfter()); +} + +ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + if (successor == &getBefore()) + return getBefore().getArguments(); + if (successor == &getAfter()) + return getAfter().getArguments(); + llvm_unreachable("invalid region successor"); } SmallVector<Region *> WhileOp::getLoopRegions() { @@ -3848,13 +3872,19 @@ void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { // All regions branch back to the parent op. 
if (!point.isParent()) { - successors.push_back(RegionSuccessor::parent(getResults())); + successors.push_back(RegionSuccessor::parent()); return; } llvm::append_range(successors, getRegions()); } +ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); +} + void IndexSwitchOp::getEntrySuccessorRegions( ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &successors) { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 7de285976f42f..d3b61f6b0624a 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -346,13 +346,19 @@ void AssumingOp::getSuccessorRegions( // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. if (!point.isParent()) { - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } regions.push_back(RegionSuccessor(&getDoRegion())); } +ValueRange AssumingOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return ValueRange(); +} + void AssumingOp::inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter) { auto *blockBeforeAssuming = rewriter.getInsertionBlock(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 55c0fbdc6f52f..60ef7cb28b778 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2605,9 +2605,15 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { // Both the operation itself and the region may be branching into the body // or back into the operation itself. 
- regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + regions.push_back(RegionSuccessor(&getRegion())); // It is possible for loop not to enter the body. - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return getRegionIterArgs(); } void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 4c5461d6f6ee6..2a8648a5d5a26 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -113,12 +113,20 @@ void transform::AlternativesOp::getSuccessorRegions( ->getParentRegion() ->getRegionNumber() + 1)) { - regions.emplace_back(&alternative, !getOperands().empty() - ? alternative.getArguments() - : Block::BlockArgListType()); + regions.emplace_back(&alternative); } if (!point.isParent()) - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange +transform::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + // With no operands, nothing is forwarded to the region's block arguments. + if (getOperands().empty()) + return ValueRange(); + return successor.getSuccessor()->getArguments(); } void transform::AlternativesOp::getRegionInvocationBounds( @@ -1738,7 +1746,7 @@ void transform::ForeachOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { Region *bodyRegion = &getBody(); if (point.isParent()) { - regions.emplace_back(bodyRegion, bodyRegion->getArguments()); + regions.emplace_back(bodyRegion); return; } @@ -1746,8 +1754,14 @@ void transform::ForeachOp::getSuccessorRegions( assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == &getBody() && "unexpected region index"); - 
regions.emplace_back(bodyRegion,
bodyRegion->getArguments()); - regions.push_back(RegionSuccessor::parent(getResults())); + regions.emplace_back(bodyRegion); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange transform::ForeachOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return getBody().getArguments(); } OperandRange @@ -2969,16 +2983,24 @@ void transform::SequenceOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { if (point.isParent()) { Region *bodyRegion = &getBody(); - regions.emplace_back(bodyRegion, getNumOperands() != 0 - ? bodyRegion->getArguments() - : Block::BlockArgListType()); + regions.emplace_back(bodyRegion); return; } assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == &getBody() && "unexpected region index"); - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange +transform::SequenceOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + // With no operands, nothing is forwarded to the body's block arguments. + if (getNumOperands() == 0) + return ValueRange(); + return getBody().getArguments(); } void transform::SequenceOp::getRegionInvocationBounds( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index fe81b9a1e7173..4f050b1e54e00 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -123,13 +123,19 @@ void transform::tune::AlternativesOp::getSuccessorRegions( if (point.isParent()) if (auto selectedRegionIdx = getSelectedRegionAttr()) regions.emplace_back( - &getAlternatives()[selectedRegionIdx->getSExtValue()], - Block::BlockArgListType()); + &getAlternatives()[selectedRegionIdx->getSExtValue()]); else for (Region &alternative : getAlternatives()) - 
regions.emplace_back(&alternative, Block::BlockArgListType()); +
regions.emplace_back(&alternative); else - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange +transform::tune::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); } void transform::tune::AlternativesOp::getRegionInvocationBounds( diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index 684d9ec3c0d14..ebd4b63145f92 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -193,7 +193,7 @@ LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) { // Verify number of successor operands and successor inputs. OperandRange succOperands = regionInterface.getSuccessorOperands(branchPoint, successor); - ValueRange succInputs = successor.getSuccessorInputs(); + ValueRange succInputs = regionInterface.getSuccessorInputs(successor); if (succOperands.size() != succInputs.size()) { return emitRegionEdgeError() << ": region branch point has " << succOperands.size() @@ -456,10 +456,10 @@ getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp, branchOp.getSuccessorRegions(src, successors); for (RegionSuccessor dst : successors) { OperandRange operands = branchOp.getSuccessorOperands(src, dst); - assert(operands.size() == dst.getSuccessorInputs().size() && + assert(operands.size() == branchOp.getSuccessorInputs(dst).size() && "expected the same number of operands and inputs"); for (const auto &[operand, input] : llvm::zip_equal( - operandsToOpOperands(operands), dst.getSuccessorInputs())) + operandsToOpOperands(operands), branchOp.getSuccessorInputs(dst))) mapping[&operand].push_back(input); } } diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 148d8b10fcd5d..67f11e3ae39cb 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp 
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -743,15 +743,27 @@ void RegionIfOp::getSuccessorRegions( if (!point.isParent()) { if (point.getTerminatorPredecessorOrNull()->getParentRegion() != &getJoinRegion()) - regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); + regions.push_back(RegionSuccessor(&getJoinRegion())); else - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); return; } // The then and else regions are the entry regions of this op. - regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); - regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); + regions.push_back(RegionSuccessor(&getThenRegion())); + regions.push_back(RegionSuccessor(&getElseRegion())); +} + +ValueRange RegionIfOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + if (successor == &getThenRegion()) + return getThenArgs(); + if (successor == &getElseRegion()) + return getElseArgs(); + if (successor == &getJoinRegion()) + return getJoinArgs(); + llvm_unreachable("invalid region successor"); } void RegionIfOp::getRegionInvocationBounds( @@ -772,7 +784,13 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, if (point.isParent()) regions.emplace_back(&getRegion()); else - regions.push_back(RegionSuccessor::parent(getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange AnyCondOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getResults(); + return ValueRange(); } void AnyCondOp::getRegionInvocationBounds( @@ -1228,11 +1246,17 @@ LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( void LoopBlockOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { - regions.emplace_back(&getBody(), getBody().getArguments()); + regions.emplace_back(&getBody()); if (point.isParent()) return; 
- regions.push_back(RegionSuccessor::parent(getOperation()->getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange LoopBlockOp::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return getBody().getArguments(); } OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { @@ -1336,9 +1360,15 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { void TestStoreWithARegion::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { if (point.isParent()) - regions.emplace_back(&getBody(), getBody().front().getArguments()); + regions.emplace_back(&getBody()); else - regions.push_back(RegionSuccessor::parent(getOperation()->getResults())); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange TestStoreWithARegion::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return getBody().front().getArguments(); } //===----------------------------------------------------------------------===// @@ -1350,9 +1380,15 @@ void TestStoreWithALoopRegion::getSuccessorRegions( // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for the operation not to // enter the body. 
- regions.emplace_back( - RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.push_back(RegionSuccessor::parent(getOperation()->getResults())); + regions.emplace_back(&getBody()); + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange +TestStoreWithALoopRegion::getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return getBody().front().getArguments(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 5417ae94f00d7..cd8656306509e 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2635,7 +2635,8 @@ def RegionIfYieldOp : TEST_Op<"region_if_yield", def RegionIfOp : TEST_Op<"region_if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getRegionInvocationBounds", - "getEntrySuccessorOperands"]>, + "getEntrySuccessorOperands", + "getSuccessorInputs"]>, SingleBlockImplicitTerminator<"RegionIfYieldOp">, RecursiveMemoryEffects]> { let description =[{ @@ -2665,7 +2666,8 @@ def RegionIfOp : TEST_Op<"region_if", def AnyCondOp : TEST_Op<"any_cond", [DeclareOpInterfaceMethods<RegionBranchOpInterface, - ["getRegionInvocationBounds"]>, + ["getRegionInvocationBounds", + "getSuccessorInputs"]>, RecursiveMemoryEffects]> { let results = (outs Variadic<AnyType>:$results); let regions = (region AnyRegion:$region); @@ -2673,7 +2675,8 @@ def AnyCondOp : TEST_Op<"any_cond", def LoopBlockOp : TEST_Op<"loop_block", [DeclareOpInterfaceMethods<RegionBranchOpInterface, - ["getEntrySuccessorOperands"]>, RecursiveMemoryEffects]> { + ["getEntrySuccessorOperands", "getSuccessorInputs"]>, + RecursiveMemoryEffects]> { let results = (outs F32:$floatResult); let arguments = (ins I32:$init); @@ -3741,7 +3744,7 @@ def TestCallOnDeviceOp : TEST_Op<"call_on_device", } def TestStoreWithARegion : TEST_Op<"store_with_a_region", - 
[DeclareOpInterfaceMethods<RegionBranchOpInterface>, + [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>, SingleBlock]> { let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$address, @@ -3753,7 +3756,7 @@ def TestStoreWithARegion : TEST_Op<"store_with_a_region", } def TestStoreWithALoopRegion : TEST_Op<"store_with_a_loop_region", - [DeclareOpInterfaceMethods<RegionBranchOpInterface>, + [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>, SingleBlock]> { let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$address, diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp index 24cef9afbac1c..655520cc2005c 100644 --- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp @@ -67,12 +67,17 @@ struct LoopRegionsOp point.getTerminatorPredecessorOrNull()->getParentRegion(); if (region == &(*this)->getRegion(1)) // This region also branches back to the parent. 
- regions.push_back( - RegionSuccessor::parent(getOperation()->getResults())); + regions.push_back(RegionSuccessor::parent()); regions.push_back(RegionSuccessor(region)); } } + ValueRange getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); + } + using RegionBranchOpInterface::Trait<LoopRegionsOp>::getSuccessorRegions; }; @@ -92,10 +97,17 @@ struct DoubleLoopRegionsOp if (point.getTerminatorPredecessorOrNull()) { Region *region = point.getTerminatorPredecessorOrNull()->getParentRegion(); - regions.push_back(RegionSuccessor::parent(getOperation()->getResults())); + regions.push_back(RegionSuccessor::parent()); regions.push_back(RegionSuccessor(region)); } } + + ValueRange getSuccessorInputs(RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return ValueRange(); + } + using RegionBranchOpInterface::Trait< DoubleLoopRegionsOp>::getSuccessorRegions; }; _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
