[llvm-branch-commits] [mlir] [mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (PR #97213)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Jul 14 02:06:55 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated.

When converting a block signature, a `BlockTypeConversionRewrite` object and potentially multiple `ReplaceBlockArgRewrite` objects are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in `BlockTypeConversionRewrite::commit` and some were replaced in `ReplaceBlockArgRewrite::commit`. The new
`BlockTypeConversionRewrite::commit` implementation is much simpler and no longer modifies any IR; that is done only in `ReplaceBlockArgRewrite` now. The `ConvertedArgInfo` data structure is no longer needed.

To that end, materializations of dropped arguments are now built in `applySignatureConversion` instead of `materializeLiveConversions`; the latter function no longer has to deal with dropped arguments.

Other minor improvements:
- Add more comments to `applySignatureConversion`.

Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in `legalizeUnresolvedMaterialization` instead of `legalizeConvertedArgumentTypes`.

This commit is in preparation for decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #<!-- -->96207.


---
Full diff: https://github.com/llvm/llvm-project/pull/97213.diff


2 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+55-118) 
- (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+2-4) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1e0afee2373a9..0b552a7e1ca3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
   Block *insertBeforeBlock;
 };
 
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
-  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                   Value castValue = nullptr)
-      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-  /// The start index of in the new argument list that contains arguments that
-  /// replace the original.
-  unsigned newArgIdx;
-
-  /// The number of arguments that replaced the original argument.
-  unsigned newArgSize;
-
-  /// The cast value that was created to cast from the new arguments to the
-  /// old. This only used if 'newArgSize' > 1.
-  Value castValue;
-};
-
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
-      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
-      const TypeConverter *converter)
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block, Block *origBlock,
+                             const TypeConverter *converter)
       : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+        origBlock(origBlock), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   /// The original block that was requested to have its signature converted.
   Block *origBlock;
 
-  /// The conversion information for each of the arguments. The information is
-  /// std::nullopt if the argument was dropped during conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
   /// The type converter used to convert the arguments.
   const TypeConverter *converter;
 };
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
 /// The type of materialization.
 enum MaterializationKind {
   /// This materialization materializes a conversion for an illegal block
-  /// argument type, to a legal one.
+  /// argument type, to the original one.
   Argument,
 
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
-  Target
+  Target,
+
+  /// This materialization materializes a conversion from a legal type back to
+  /// an illegal one.
+  Source
 };
 
 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
 };
 } // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        ValueRange inputs, Type outputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
-                                               ValueRange inputs,
-                                               Type outputType,
-                                               const TypeConverter *converter);
-
   Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                              Type outputType,
                                              const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     for (Operation *op : block->getUsers())
       listener->notifyOperationModified(op);
-
-  // Process the remapping for each of the original arguments.
-  for (auto [origArg, info] :
-       llvm::zip_equal(origBlock->getArguments(), argInfo)) {
-    // Handle the case of a 1->0 value mapping.
-    if (!info) {
-      if (Value newArg =
-              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        rewriter.replaceAllUsesWith(origArg, newArg);
-      continue;
-    }
-
-    // Otherwise this is a 1->1+ value mapping.
-    Value castValue = info->castValue;
-    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-    // If the argument is still used, replace it with the generated cast.
-    if (!origArg.use_empty()) {
-      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
-                                               castValue, origArg.getType()));
-    }
-  }
 }
 
 void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
       continue;
 
     Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
-    bool isDroppedArg = replacementValue == origArg;
-    if (!isDroppedArg)
-      builder.setInsertionPointAfterValue(replacementValue);
+    assert(replacementValue && "replacement value not found");
     Value newArg;
     if (converter) {
+      builder.setInsertionPointAfterValue(replacementValue);
       newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(),
-          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+          builder, origArg.getLoc(), origArg.getType(), replacementValue);
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
              "type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
           << "failed to materialize conversion for block argument #"
           << it.index() << " that remained live after conversion, type was "
           << origArg.getType();
-      if (!isDroppedArg)
-        diag << ", with target type " << replacementValue.getType();
       diag.attachNote(liveUser->getLoc())
           << "see existing live user here: " << *liveUser;
       return failure();
@@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
 
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-  argInfo.resize(origArgCount);
-
   for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
     BlockArgument origArg = block->getArgument(i);
+    Type origArgType = origArg.getType();
+
+    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+        signatureConversion.getInputMapping(i);
+    if (!inputMap) {
+      // This block argument was dropped and no replacement value was provided.
+      // Materialize a replacement value "out of thin air".
+      Value repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, newBlock, newBlock->begin(),
+          origArg.getLoc(), /*inputs=*/ValueRange(),
+          /*outputType=*/origArgType, converter);
+      mapping.map(origArg, repl);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      continue;
+    }
 
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
+    if (Value repl = inputMap->replacementValue) {
+      // This block argument was dropped and a replacement value was provided.
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
+      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
-    // Otherwise, this is a 1->1+ mapping.
+    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+    // dialect conversion. Therefore, we need an argument materialization to
+    // turn the replacement block arguments into a single SSA value that can be
+    // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
+    Value argMat = buildUnresolvedMaterialization(
+        MaterializationKind::Argument, newBlock, newBlock->begin(),
+        origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
+    mapping.map(origArg, argMat);
+    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
 
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
     // FIXME: We simply pass through the replacement argument if there wasn't a
     // converter, which isn't great as it allows implicit type conversions to
     // appear. We should properly restructure this code to handle cases where a
     // converter isn't provided and also to properly handle the case where an
     // argument materialization is actually a temporary source materialization
     // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-      mapping.map(origArg, newArg);
-    } else {
-      // Build argument materialization: new block arguments -> old block
-      // argument type.
-      Value argMat = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
-      mapping.map(origArg, argMat);
-
-      // Build target materialization: old block argument type -> legal type.
-      // Note: This function returns an "empty" type if no valid conversion to
-      // a legal type exists. In that case, we continue the conversion with the
-      // original block argument type.
-      Type legalOutputType = converter->convertType(origArg.getType());
-      if (legalOutputType && legalOutputType != origArg.getType()) {
-        newArg = buildUnresolvedTargetMaterialization(
-            origArg.getLoc(), argMat, legalOutputType, converter);
-        mapping.map(argMat, newArg);
-      } else {
-        newArg = argMat;
-      }
+    Type legalOutputType;
+    if (converter)
+      legalOutputType = converter->convertType(origArgType);
+    if (legalOutputType && legalOutputType != origArgType) {
+      Value targetMat = buildUnresolvedTargetMaterialization(
+          origArg.getLoc(), argMat, legalOutputType, converter);
+      mapping.map(argMat, targetMat);
     }
-
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
-    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
-                                            converter);
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type outputType,
-    const TypeConverter *converter) {
-  return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
-                                        block->begin(), loc, inputs, outputType,
-                                        converter);
-}
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
     const TypeConverter *converter) {
@@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), outputType, inputOperands);
       break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
     }
     if (newMaterialization) {
       assert(newMaterialization.getType() == outputType &&
@@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
 
   InFlightDiagnostic diag = op->emitError()
                             << "failed to legalize unresolved materialization "
-                               "from "
-                            << inputOperands.getTypes() << " to " << outputType
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
                             << " that remained live after conversion";
   if (Operation *liveUser = findLiveUser(op->getUsers())) {
     diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index b35cda8e724f6..8254be68912c8 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,9 +2,8 @@
 
 
 func.func @test_invalid_arg_materialization(
-  // expected-error at below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+  // expected-error at below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
   %arg0: i16) {
-  // expected-note at below {{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }
 
@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
 // Make sure argument type changes aren't implicitly forwarded.
 func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
-  // expected-error at below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+  // expected-error at below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
   ^bb0(%arg0: f32):
-    // expected-note at below {{see existing live user here}}
     "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()

``````````

</details>


https://github.com/llvm/llvm-project/pull/97213


More information about the llvm-branch-commits mailing list