llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Until now, `transform.apply_conversion_patterns` consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to `transform.apply_patterns`, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion. This new functionality is hidden behind a `preserve_handles` attribute for now. --- Full diff: https://github.com/llvm/llvm-project/pull/83950.diff 5 Files Affected: - (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+24-8) - (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+14-4) - (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+21-18) - (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+40-5) - (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+30) ``````````diff diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 313cdc27f780a7..32724ff4b98e8e 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -921,20 +921,36 @@ TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { return RegionScope(*this, region); } +/// A configuration object for customizing a `TrackingListener`. +struct TrackingListenerConfig { + using SkipHandleFn = std::function<bool(Value)>; + + /// An optional function that returns "true" for handles that do not have to + /// be updated. These are typically dead or consumed handles. + SkipHandleFn skipHandleFn = nullptr; + + /// If set to "true", the name of a replacement op must match the name of the + /// original op. 
If set to "false", the names of the payload ops tracked in a + /// handle may change as the tracking listener updates the transform state. + bool requireMatchingReplacementOpName = true; + + /// If set to "true", cast ops (that implement the CastOpInterface) are + /// skipped and the replacement op search continues with the operands of the + /// cast op. + bool skipCastOps = true; +}; + /// A listener that updates a TransformState based on IR modifications. This /// listener can be used during a greedy pattern rewrite to keep the transform /// state up-to-date. class TrackingListener : public RewriterBase::Listener, public TransformState::Extension { public: - /// A function that returns "true" for handles that do not have to be updated. - using SkipHandleFn = std::function<bool(Value)>; - /// Create a new TrackingListener for usage in the specified transform op. /// Optionally, a function can be specified to identify handles that should /// do not have to be updated. TrackingListener(TransformState &state, TransformOpInterface op, - SkipHandleFn skipHandleFn = nullptr); + TrackingListenerConfig config = TrackingListenerConfig()); protected: /// Return a replacement payload op for the given op, which is going to be @@ -959,7 +975,8 @@ class TrackingListener : public RewriterBase::Listener, /// same computation; e.g., there may be tiled "linalg.generic" inside the /// loop body that represents the original computation. Therefore, the /// TrackingListener is conservative by default: it drops the mapping and - /// triggers the "payload replacement not found" notification. + /// triggers the "payload replacement not found" notification. This default + /// behavior can be customized in `TrackingListenerConfig`. 
/// /// If no replacement op could be found according to the rules mentioned /// above, this function tries to skip over cast-like ops that implement @@ -1023,9 +1040,8 @@ class TrackingListener : public RewriterBase::Listener, /// The handles that are consumed by the transform op. DenseSet<Value> consumedHandles; - /// Handles for which this function evaluates to "true" do not have to be - /// updated. These are typically dead or consumed handles. - SkipHandleFn skipHandleFn; + /// Tracking listener configuration. + TrackingListenerConfig config; }; /// A specialized listener that keeps track of cases in which no replacement diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 9f513822ed0a4e..0e42d12a69a400 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -190,11 +190,20 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns", The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects` attributes specify the conversion target. - This transform consumes the `target` handle and modifies the payload. It - does not produce any handles. + This transform modifies the payload. By default, it consumes the `target` + handle. It does not produce any handles. + + If the `preserve_handles` attribute is set, this transform does not consume + the `target` handle and instead updates handles based on notifications from + a tracking listener that is attached to the dialect conversion, similar to + `transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp` + or `replaceOpWithNewOp` are considered "payload op replacements". In + contrast to `transform.apply_patterns`, we allow replacement ops even if the + op name has changed. More details can be found in the documentation of + `TrackingListener`.
This transform produces a silenceable failure if the dialect conversion was - unsuccessful. + unsuccessful or the tracking listener failed to find a replacement op. }]; let arguments = (ins TransformHandleTypeInterface:$target, @@ -202,7 +211,8 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns", OptionalAttr<StrArrayAttr>:$illegal_ops, OptionalAttr<StrArrayAttr>:$legal_dialects, OptionalAttr<StrArrayAttr>:$illegal_dialects, - UnitAttr:$partial_conversion); + UnitAttr:$partial_conversion, + UnitAttr:$preserve_handles); let results = (outs); let regions = (region MaxSizedRegion<1>:$patterns, diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index bb9f6fec452986..71a9d61198e3fb 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -918,7 +918,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } // Prepare rewriter and listener. - TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) { + TrackingListenerConfig config; + config.skipHandleFn = [&](Value handle) { // Skip handle if it is dead. 
auto scopeIt = llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) { @@ -935,7 +936,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { return true; }; transform::ErrorCheckingTrackingListener trackingListener(*this, transform, - skipHandleFn); + config); transform::TransformRewriter rewriter(transform->getContext(), &trackingListener); @@ -1184,9 +1185,8 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const { transform::TrackingListener::TrackingListener(TransformState &state, TransformOpInterface op, - SkipHandleFn skipHandleFn) - : TransformState::Extension(state), transformOp(op), - skipHandleFn(skipHandleFn) { + TrackingListenerConfig config) + : TransformState::Extension(state), transformOp(op), config(config) { if (op) { for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) { consumedHandles.insert(opOperand->get()); @@ -1228,8 +1228,19 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp( return diag; } - // If the defining op has the same type, we take it as a replacement. - if (op->getName() == defOp->getName()) { + // Skip through ops that implement CastOpInterface. + if (config.skipCastOps && isa<CastOpInterface>(defOp)) { + values.clear(); + values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); + diag.attachNote(defOp->getLoc()) + << "using output of 'CastOpInterface' op"; + continue; + } + + // If the defining op has the same name or we do not care about the name of + // op replacements at all, we take it as a replacement. + if (!config.requireMatchingReplacementOpName || + op->getName() == defOp->getName()) { result = defOp; return DiagnosedSilenceableFailure::success(); } @@ -1251,14 +1262,6 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp( "'FindPayloadReplacementOpInterface'"; continue; } - - // Skip through ops that implement CastOpInterface. 
- if (isa<CastOpInterface>(defOp)) { - values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); - diag.attachNote(defOp->getLoc()) - << "using output of 'CastOpInterface' op"; - continue; - } } while (!values.empty()); diag.attachNote() << "ran out of suitable replacement values"; @@ -1318,9 +1321,9 @@ void transform::TrackingListener::notifyOperationReplaced( // Check if there are any handles that must be updated. Value aliveHandle; - if (skipHandleFn) { - auto it = - llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); }); + if (config.skipHandleFn) { + auto it = llvm::find_if(opHandles, + [&](Value v) { return !config.skipHandleFn(v); }); if (it != opHandles.end()) aliveHandle = *it; } else if (!opHandles.empty()) { diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 180d11c30e65de..ca80899ab07341 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -563,6 +563,17 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( } } + // Attach a tracking listener if handles should be preserved. We configure the + // listener to allow op replacements with different names, as conversion + // patterns typically replace ops with replacement ops that have a different + // name. + TrackingListenerConfig trackingConfig; + trackingConfig.requireMatchingReplacementOpName = false; + ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig); + ConversionConfig conversionConfig; + if (getPreserveHandles()) + conversionConfig.listener = &trackingListener; + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (Operation *target : state.getPayloadOps(getTarget())) { // Make sure that this transform is not applied to itself. 
Modifying the @@ -574,16 +585,36 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( LogicalResult status = failure(); if (getPartialConversion()) { - status = applyPartialConversion(target, conversionTarget, frozenPatterns); + status = applyPartialConversion(target, conversionTarget, frozenPatterns, + conversionConfig); } else { - status = applyFullConversion(target, conversionTarget, frozenPatterns); + status = applyFullConversion(target, conversionTarget, frozenPatterns, + conversionConfig); } + // Check dialect conversion state. + DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); if (failed(status)) { - auto diag = emitSilenceableError() << "dialect conversion failed"; + diag = emitSilenceableError() << "dialect conversion failed"; diag.attachNote(target->getLoc()) << "target op"; - return diag; } + + // Check tracking listener error state. + DiagnosedSilenceableFailure trackingFailure = + trackingListener.checkAndResetError(); + if (!trackingFailure.succeeded()) { + if (diag.succeeded()) { + // Tracking failure is the only failure. 
+ return trackingFailure; + } else { + diag.attachNote() << "tracking listener also failed: " + << trackingFailure.getMessage(); + (void)trackingFailure.silence(); + } + } + + if (!diag.succeeded()) + return diag; } return DiagnosedSilenceableFailure::success(); @@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() { void transform::ApplyConversionPatternsOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { - transform::consumesHandle(getTarget(), effects); + if (!getPreserveHandles()) { + transform::consumesHandle(getTarget(), effects); + } else { + transform::onlyReadsHandle(getTarget(), effects); + } transform::modifiesPayload(effects); } diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir index 0c41e81b17b522..fa8a555af92188 100644 --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -417,3 +417,33 @@ module attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +// "test.foo" is tracked and replaced with "test.new_op" during a dialect +// conversion. Make sure that the handle is updated accordingly. 
+ +// CHECK-LABEL: func @dialect_conversion_tracking +// CHECK-NEXT: %[[m:.*]] = "test.new_op"() {annotated} : () -> memref<5xf32> +// CHECK-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32> +// CHECK-NEXT: return %[[cast]] +func.func @dialect_conversion_tracking() -> tensor<5xf32> { + %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>) + return %0 : tensor<5xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.foo"]} in %0 : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %0 { + transform.apply_conversion_patterns.transform.test_conversion_patterns + } with type_converter { + transform.apply_conversion_patterns.transform.test_type_converter + } {legal_ops = ["func.func", "func.return", "test.new_op"], preserve_handles} + : !transform.any_op + // Add an attribute to %1, which is now mapped to a new op. + transform.annotate %1 "annotated" : !transform.any_op + transform.yield + } +} `````````` </details> https://github.com/llvm/llvm-project/pull/83950 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits