llvmbot wrote:

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit simplifies the handling of dropped arguments and updates outdated
dialect conversion documentation.

When converting a block signature, a `BlockTypeConversionRewrite` object and
potentially multiple `ReplaceBlockArgRewrite` objects are created. During the "commit"
phase, uses of the old block arguments are replaced with the new block 
arguments, but the old implementation was written in an inconsistent way: some 
block arguments were replaced in `BlockTypeConversionRewrite::commit` and some 
were replaced in `ReplaceBlockArgRewrite::commit`. The new 
`BlockTypeConversionRewrite::commit` implementation is much simpler and no 
longer modifies any IR; that is done only in `ReplaceBlockArgRewrite` now. The 
`ConvertedArgInfo` data structure is no longer needed.
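
The division of labor described above can be sketched with a small, self-contained toy model. This is not MLIR code: `ToyIR`, `ToyReplaceBlockArg`, and `ToyBlockTypeConversion` are hypothetical names, and values are modeled as plain strings. The sketch only illustrates the idea that one rewrite kind redirects uses while the other leaves the IR untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy IR: every operand slot just stores the name of a value.
struct ToyIR {
  std::vector<std::string> operands;
};

// Plays the role of ReplaceBlockArgRewrite: the only place where uses of an
// old block argument are redirected to its replacement.
struct ToyReplaceBlockArg {
  std::string oldArg;

  void commit(ToyIR &ir,
              const std::map<std::string, std::string> &mapping) const {
    std::map<std::string, std::string>::const_iterator it =
        mapping.find(oldArg);
    if (it == mapping.end())
      return; // No replacement recorded; leave the uses alone.
    for (std::string &operand : ir.operands)
      if (operand == oldArg)
        operand = it->second;
  }
};

// Plays the role of BlockTypeConversionRewrite after this change: commit
// takes the IR by const reference because it no longer modifies it. It only
// reports how many operand slots a listener would be notified about.
struct ToyBlockTypeConversion {
  std::size_t commit(const ToyIR &ir) const { return ir.operands.size(); }
};
```

In this model, committing a `ToyBlockTypeConversion` leaves every operand unchanged, and committing one `ToyReplaceBlockArg` per old argument performs all replacements in a single, uniform place.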

To that end, materializations of dropped arguments are now built in 
`applySignatureConversion` instead of `materializeLiveConversions`; the latter 
function no longer has to deal with dropped arguments.
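
As a rough illustration, the per-argument decision that `applySignatureConversion` makes in this patch can be modeled as a three-way classification. The sketch below is a simplified stand-in and not the real implementation: `ToyInputMapping`, `ToyAction`, and `classifyArgument` are hypothetical names, the field names follow the diff (`inputNo`, `size`, `replacementValue`), and a null pointer stands in for an empty `std::optional`.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified stand-in for SignatureConversion::InputMapping.
struct ToyInputMapping {
  unsigned inputNo;
  unsigned size;
  std::string replacementValue; // Empty means "no replacement provided".

  ToyInputMapping() : inputNo(0), size(0) {}
};

enum class ToyAction {
  // Dropped, no replacement: build a source materialization with no inputs.
  SourceMaterialization,
  // Dropped, and the caller supplied a replacement value.
  UseProvidedReplacement,
  // 1->1+ mapping: combine the new block arguments into a single value.
  ArgumentMaterialization
};

// Classify one original block argument. A null `inputMap` models the case
// where no input mapping exists, i.e. the argument was dropped.
ToyAction classifyArgument(const ToyInputMapping *inputMap) {
  if (!inputMap)
    return ToyAction::SourceMaterialization;
  if (!inputMap->replacementValue.empty()) {
    assert(inputMap->size == 0 &&
           "a replacement value is only valid for a dropped argument");
    return ToyAction::UseProvidedReplacement;
  }
  return ToyAction::ArgumentMaterialization;
}
```

In the patch itself, all three branches end the same way: the original argument is mapped to the chosen replacement and a `ReplaceBlockArgRewrite` is appended.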

Other minor improvements:
- Improve variable name: `origOutputType` -> `origArgType`. Add an assertion
to check that this field is only used for argument materializations.
- Add more comments to `applySignatureConversion`.
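
The new assertion and the widened `PointerIntPair` in the diff can be illustrated with a standalone sketch. The names `ToyKind`, `isValidOrigArgType`, and `bitsFor` are hypothetical, and types are modeled as strings where an empty string means "no type". The sketch assumes the enum previously had two kinds (`Argument` and `Target`), which is consistent with the one-bit field in the removed line.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of MaterializationKind after this patch.
enum class ToyKind { Argument, Target, Source };

// The invariant checked by the new assertion: an original argument type may
// only be set when the materialization kind is Argument.
bool isValidOrigArgType(ToyKind kind, const std::string &origArgType) {
  return kind == ToyKind::Argument || origArgType.empty();
}

// Number of bits needed to store `numValues` distinct enumerators. Written
// recursively so that it is a valid C++11 constexpr function.
constexpr unsigned bitsFor(unsigned numValues) {
  return numValues <= 1 ? 0 : 1 + bitsFor((numValues + 1) / 2);
}

static_assert(bitsFor(2) == 1, "two kinds fit in one bit");
static_assert(bitsFor(3) == 2, "three kinds need two bits");
```

Adding the `Source` kind takes the enum from two to three values, which is why the integer part of the pair grows from one bit to two.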

Note: Error messages around failed materializations for dropped basic block 
arguments changed slightly. That is because those materializations are now 
built in `legalizeUnresolvedMaterialization` instead of 
`legalizeConvertedArgumentTypes`.

This commit is in preparation for decoupling argument/source/target
materializations from the dialect conversion.


---

Patch is 24.50 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/96207.diff


4 Files Affected:

- (modified) mlir/docs/DialectConversion.md (+25-12) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+6-4) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+78-130) 
- (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+2-4) 


``````````diff
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 69781bb868bbf..f722974a9a1e5 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -246,6 +246,13 @@ depending on the situation.
 
     -   An argument materialization is used when converting the type of a block
         argument during a [signature conversion](#region-signature-conversion).
+        The new block argument types are specified in a `SignatureConversion`
+        object. An original block argument can be converted into multiple
+        block arguments, which is not supported everywhere in the dialect
+        conversion. (E.g., adaptors support only a single replacement value for
+        each original value.) Therefore, an argument materialization is used to
+        convert potentially multiple new block arguments back into a single SSA
+        value.
 
 *   Source Materialization
 
@@ -259,6 +266,9 @@ depending on the situation.
         *   When a block argument has been converted to a different type, but
             the original argument still has users that will remain live after
             the conversion process has finished.
+        *   When a block argument has been dropped, but the argument still has
+            users that will remain live after the conversion process has
+            finished.
         *   When the result type of an operation has been converted to a
             different type, but the original result still has users that will
             remain live after the conversion process is finished.
@@ -330,17 +340,19 @@ class TypeConverter {
 
   /// Register a materialization function, which must be convertible to the
   /// following form:
-  ///   `Optional<Value> (OpBuilder &, T, ValueRange, Location)`,
-  ///   where `T` is any subclass of `Type`.
-  /// This function is responsible for creating an operation, using the
-  /// OpBuilder and Location provided, that "converts" a range of values into a
-  /// single value of the given type `T`. It must return a Value of the
-  /// converted type on success, an `std::nullopt` if it failed but other
-  /// materialization can be attempted, and `nullptr` on unrecoverable failure.
-  /// It will only be called for (sub)types of `T`.
+  ///   `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+  /// where `T` is any subclass of `Type`. This function is responsible for
+  /// creating an operation, using the OpBuilder and Location provided, that
+  /// "casts" a range of values into a single value of the given type `T`. It
+  /// must return a Value of the converted type on success, an `std::nullopt` if
+  /// it failed but other materialization can be attempted, and `nullptr` on
+  /// unrecoverable failure. It will only be called for (sub)types of `T`.
+  /// Materialization functions must be provided when a type conversion may
+  /// persist after the conversion has finished.
   ///
   /// This method registers a materialization that will be called when
-  /// converting an illegal block argument type, to a legal type.
+  /// converting potentially multiple replacement block arguments (of a single
+  /// original block argument), to a single SSA value with a legal type.
   template <typename FnT,
             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
@@ -348,8 +360,9 @@ class TypeConverter {
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }
   /// This method registers a materialization that will be called when
-  /// converting a legal type to an illegal source type. This is used when
-  /// conversions to an illegal type must persist beyond the main conversion.
+  /// converting a legal replacement value back to an illegal source type.
+  /// This is used when some uses of the original, illegal value must persist
+  /// beyond the main conversion.
   template <typename FnT,
             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
   void addSourceMaterialization(FnT &&callback) {
@@ -357,7 +370,7 @@ class TypeConverter {
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }
   /// This method registers a materialization that will be called when
-  /// converting type from an illegal, or source, type to a legal type.
+  /// converting an illegal (source) value to a legal (target) type.
   template <typename FnT,
             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f83f3a3fdf992..87b5dd9a6f340 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,7 +181,8 @@ class TypeConverter {
   /// persist after the conversion has finished.
   ///
   /// This method registers a materialization that will be called when
-  /// converting an illegal block argument type, to a legal type.
+  /// converting potentially multiple replacement block arguments (of a single
+  /// original block argument), to a single SSA value with a legal type.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
@@ -189,8 +190,9 @@ class TypeConverter {
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }
   /// This method registers a materialization that will be called when
-  /// converting a legal type to an illegal source type. This is used when
-  /// conversions to an illegal type must persist beyond the main conversion.
+  /// converting a legal replacement value back to an illegal source type.
+  /// This is used when some uses of the original, illegal value must persist
+  /// beyond the main conversion.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addSourceMaterialization(FnT &&callback) {
@@ -198,7 +200,7 @@ class TypeConverter {
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }
   /// This method registers a materialization that will be called when
-  /// converting type from an illegal, or source, type to a legal type.
+  /// converting an illegal (source) value to a legal (target) type.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e6c0ee2ab2949..07ebd687ee2b3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
   Block *insertBeforeBlock;
 };
 
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
-  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                   Value castValue = nullptr)
-      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-  /// The start index of in the new argument list that contains arguments that
-  /// replace the original.
-  unsigned newArgIdx;
-
-  /// The number of arguments that replaced the original argument.
-  unsigned newArgSize;
-
-  /// The cast value that was created to cast from the new arguments to the
-  /// old. This only used if 'newArgSize' > 1.
-  Value castValue;
-};
-
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
-      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
-      const TypeConverter *converter)
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block, Block *origBlock,
+                             const TypeConverter *converter)
       : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+        origBlock(origBlock), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   /// The original block that was requested to have its signature converted.
   Block *origBlock;
 
-  /// The conversion information for each of the arguments. The information is
-  /// std::nullopt if the argument was dropped during conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
   /// The type converter used to convert the arguments.
   const TypeConverter *converter;
 };
@@ -696,7 +672,11 @@ enum MaterializationKind {
 
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
-  Target
+  Target,
+
+  /// This materialization materializes a conversion from a legal type back to
+  /// an illegal one.
+  Source
 };
 
 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -708,9 +688,13 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
       MaterializationKind kind = MaterializationKind::Target,
-      Type origOutputType = nullptr)
+      Type origArgType = nullptr)
       : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+        converterAndKind(converter, kind), origArgType(origArgType) {
+    assert(kind == MaterializationKind::Argument ||
+           !origArgType && "orginal argument type make sense only for argument "
+                           "materializations");
+  }
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +718,17 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return converterAndKind.getInt();
   }
 
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
+  /// Return the original type of the block argument.
+  Type getOrigArgType() const { return origArgType; }
 
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
 
   /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
+  Type origArgType;
 };
 } // namespace
 
@@ -862,13 +846,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        ValueRange inputs, Type outputType,
                                        Type origOutputType,
                                        const TypeConverter *converter);
-
-  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
-                                               ValueRange inputs,
-                                               Type origOutputType,
-                                               Type outputType,
-                                               const TypeConverter *converter);
-
   Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                              Type outputType,
                                              const TypeConverter *converter);
@@ -998,28 +975,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     for (Operation *op : block->getUsers())
       listener->notifyOperationModified(op);
-
-  // Process the remapping for each of the original arguments.
-  for (auto [origArg, info] :
-       llvm::zip_equal(origBlock->getArguments(), argInfo)) {
-    // Handle the case of a 1->0 value mapping.
-    if (!info) {
-      if (Value newArg =
-              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        rewriter.replaceAllUsesWith(origArg, newArg);
-      continue;
-    }
-
-    // Otherwise this is a 1->1+ value mapping.
-    Value castValue = info->castValue;
-    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-    // If the argument is still used, replace it with the generated cast.
-    if (!origArg.use_empty()) {
-      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
-                                               castValue, origArg.getType()));
-    }
-  }
 }
 
 void BlockTypeConversionRewrite::rollback() {
@@ -1043,15 +998,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
     if (!liveUser)
       continue;
 
-    Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
-    bool isDroppedArg = replacementValue == origArg;
-    if (!isDroppedArg)
-      builder.setInsertionPointAfterValue(replacementValue);
+    Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+    assert(replacementValue && "replacement value not found");
     Value newArg;
     if (converter) {
+      builder.setInsertionPointAfterValue(replacementValue);
       newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(),
-          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+          builder, origArg.getLoc(), origArg.getType(), replacementValue);
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
              "type");
@@ -1062,8 +1015,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
           << "failed to materialize conversion for block argument #"
           << it.index() << " that remained live after conversion, type was "
           << origArg.getType();
-      if (!isDroppedArg)
-        diag << ", with target type " << replacementValue.getType();
       diag.attachNote(liveUser->getLoc())
           << "see existing live user here: " << *liveUser;
       return failure();
@@ -1349,65 +1300,65 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
 
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-  argInfo.resize(origArgCount);
-
   for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
     BlockArgument origArg = block->getArgument(i);
+    Type origArgType = origArg.getType();
 
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
-      assert(inputMap->size == 0 &&
-             "invalid to provide a replacement value when the argument isn't "
-             "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
-      continue;
-    }
-
-    // Otherwise, this is a 1->1+ mapping.
-    auto replArgs =
-        newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
-
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
+    // Helper function that tries to legalize the given type. Returns the given
+    // type if it could not be legalized.
     // FIXME: We simply pass through the replacement argument if there wasn't a
     // converter, which isn't great as it allows implicit type conversions to
     // appear. We should properly restructure this code to handle cases where a
     // converter isn't provided and also to properly handle the case where an
     // argument materialization is actually a temporary source materialization
     // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-    } else {
-      Type origOutputType = origArg.getType();
+    auto tryLegalizeType = [&](Type type) {
+      if (converter)
+        if (Type t = converter->convertType(type))
+          return t;
+      return type;
+    };
 
-      // Legalize the argument output type.
-      Type outputType = origOutputType;
-      if (Type legalOutputType = converter->convertType(outputType))
-        outputType = legalOutputType;
+    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+        signatureConversion.getInputMapping(i);
+    if (!inputMap) {
+      // This block argument was dropped and no replacement value was provided.
+      // Materialize a replacement value "out of thin air".
+      Value repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, newBlock, newBlock->begin(),
+          origArg.getLoc(), /*inputs=*/ValueRange(),
+          /*outputType=*/origArgType, /*origArgType=*/{}, converter);
+      mapping.map(origArg, repl);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      continue;
+    }
 
-      newArg = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter);
+    if (Value repl = inputMap->replacementValue) {
+      // This block argument was dropped and a replacement value was provided.
+      assert(inputMap->size == 0 &&
+             "invalid to provide a replacement value when the argument isn't "
+             "dropped");
+      mapping.map(origArg, repl);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      continue;
     }
 
-    mapping.map(origArg, newArg);
+    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+    // dialect conversion. Therefore, we need an argument materialization to
+    // turn the replacement block arguments into a single SSA value that can be
+    // used as a replacement. The type of this SSA value is the legalized
+    // version of the original block argument type.
+    auto replArgs =
+        newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
+    Value repl = buildUnresolvedMaterialization(
+        MaterializationKind::Argument, newBlock, newBlock->begin(),
+        origArg.getLoc(), /*inputs=*/replArgs,
+        /*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
+    mapping.map(origArg, repl);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
-    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
-                                            converter);
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1424,7 +1375,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
-    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    Location loc, ValueRange inputs, Type outputType, Type origArgType,
     const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1436,16 +1387,9 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, out...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/96207