[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (PR #97213)

Matthias Springer llvmlistbot at llvm.org
Sat Jul 20 00:24:41 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/97213

>From 4114d5be87596e11d86706a338248ebf05cf7150 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 13 Jul 2024 17:36:37 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Simplify handling of
 dropped arguments

This commit simplifies the handling of dropped arguments and updates some outdated dialect conversion documentation.

When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite objects are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation did this inconsistently: some block arguments were replaced in BlockTypeConversionRewrite::commit and others in ReplaceBlockArgRewrite::commit. The new BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is now done only in ReplaceBlockArgRewrite. The ConvertedArgInfo data structure is no longer needed.

To that end, materializations of dropped arguments are now built in applySignatureConversion (as unresolved source materializations) instead of in materializeLiveConversions; the latter function no longer has to deal with dropped arguments.

Other minor improvements:

- Improve variable name: origOutputType -> origArgType. Add an assertion to check that this field is only used for argument materializations.
- Add more comments to applySignatureConversion.
Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now resolved (and, on failure, reported) in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes.

This commit is in preparation for decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #96207.
---
 .../Transforms/Utils/DialectConversion.cpp    | 173 ++++++------------
 .../test-legalize-type-conversion.mlir        |   6 +-
 2 files changed, 57 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1e0afee2373a9..0b552a7e1ca3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
   Block *insertBeforeBlock;
 };
 
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
-  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                   Value castValue = nullptr)
-      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-  /// The start index of in the new argument list that contains arguments that
-  /// replace the original.
-  unsigned newArgIdx;
-
-  /// The number of arguments that replaced the original argument.
-  unsigned newArgSize;
-
-  /// The cast value that was created to cast from the new arguments to the
-  /// old. This only used if 'newArgSize' > 1.
-  Value castValue;
-};
-
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
-      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
-      const TypeConverter *converter)
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block, Block *origBlock,
+                             const TypeConverter *converter)
       : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+        origBlock(origBlock), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   /// The original block that was requested to have its signature converted.
   Block *origBlock;
 
-  /// The conversion information for each of the arguments. The information is
-  /// std::nullopt if the argument was dropped during conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
   /// The type converter used to convert the arguments.
   const TypeConverter *converter;
 };
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
 /// The type of materialization.
 enum MaterializationKind {
   /// This materialization materializes a conversion for an illegal block
-  /// argument type, to a legal one.
+  /// argument type, to the original one.
   Argument,
 
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
-  Target
+  Target,
+
+  /// This materialization materializes a conversion from a legal type back to
+  /// an illegal one.
+  Source
 };
 
 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
 };
 } // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        ValueRange inputs, Type outputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
-                                               ValueRange inputs,
-                                               Type outputType,
-                                               const TypeConverter *converter);
-
   Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                              Type outputType,
                                              const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     for (Operation *op : block->getUsers())
       listener->notifyOperationModified(op);
-
-  // Process the remapping for each of the original arguments.
-  for (auto [origArg, info] :
-       llvm::zip_equal(origBlock->getArguments(), argInfo)) {
-    // Handle the case of a 1->0 value mapping.
-    if (!info) {
-      if (Value newArg =
-              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        rewriter.replaceAllUsesWith(origArg, newArg);
-      continue;
-    }
-
-    // Otherwise this is a 1->1+ value mapping.
-    Value castValue = info->castValue;
-    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-    // If the argument is still used, replace it with the generated cast.
-    if (!origArg.use_empty()) {
-      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
-                                               castValue, origArg.getType()));
-    }
-  }
 }
 
 void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
       continue;
 
     Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
-    bool isDroppedArg = replacementValue == origArg;
-    if (!isDroppedArg)
-      builder.setInsertionPointAfterValue(replacementValue);
+    assert(replacementValue && "replacement value not found");
     Value newArg;
     if (converter) {
+      builder.setInsertionPointAfterValue(replacementValue);
       newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(),
-          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+          builder, origArg.getLoc(), origArg.getType(), replacementValue);
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
              "type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
           << "failed to materialize conversion for block argument #"
           << it.index() << " that remained live after conversion, type was "
           << origArg.getType();
-      if (!isDroppedArg)
-        diag << ", with target type " << replacementValue.getType();
       diag.attachNote(liveUser->getLoc())
           << "see existing live user here: " << *liveUser;
       return failure();
@@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
 
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-  argInfo.resize(origArgCount);
-
   for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
     BlockArgument origArg = block->getArgument(i);
+    Type origArgType = origArg.getType();
+
+    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+        signatureConversion.getInputMapping(i);
+    if (!inputMap) {
+      // This block argument was dropped and no replacement value was provided.
+      // Materialize a replacement value "out of thin air".
+      Value repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, newBlock, newBlock->begin(),
+          origArg.getLoc(), /*inputs=*/ValueRange(),
+          /*outputType=*/origArgType, converter);
+      mapping.map(origArg, repl);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      continue;
+    }
 
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
+    if (Value repl = inputMap->replacementValue) {
+      // This block argument was dropped and a replacement value was provided.
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
+      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
-    // Otherwise, this is a 1->1+ mapping.
+    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+    // dialect conversion. Therefore, we need an argument materialization to
+    // turn the replacement block arguments into a single SSA value that can be
+    // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
+    Value argMat = buildUnresolvedMaterialization(
+        MaterializationKind::Argument, newBlock, newBlock->begin(),
+        origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
+    mapping.map(origArg, argMat);
+    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
 
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
     // FIXME: We simply pass through the replacement argument if there wasn't a
     // converter, which isn't great as it allows implicit type conversions to
     // appear. We should properly restructure this code to handle cases where a
     // converter isn't provided and also to properly handle the case where an
     // argument materialization is actually a temporary source materialization
     // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-      mapping.map(origArg, newArg);
-    } else {
-      // Build argument materialization: new block arguments -> old block
-      // argument type.
-      Value argMat = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
-      mapping.map(origArg, argMat);
-
-      // Build target materialization: old block argument type -> legal type.
-      // Note: This function returns an "empty" type if no valid conversion to
-      // a legal type exists. In that case, we continue the conversion with the
-      // original block argument type.
-      Type legalOutputType = converter->convertType(origArg.getType());
-      if (legalOutputType && legalOutputType != origArg.getType()) {
-        newArg = buildUnresolvedTargetMaterialization(
-            origArg.getLoc(), argMat, legalOutputType, converter);
-        mapping.map(argMat, newArg);
-      } else {
-        newArg = argMat;
-      }
+    Type legalOutputType;
+    if (converter)
+      legalOutputType = converter->convertType(origArgType);
+    if (legalOutputType && legalOutputType != origArgType) {
+      Value targetMat = buildUnresolvedTargetMaterialization(
+          origArg.getLoc(), argMat, legalOutputType, converter);
+      mapping.map(argMat, targetMat);
     }
-
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
-    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
-                                            converter);
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type outputType,
-    const TypeConverter *converter) {
-  return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
-                                        block->begin(), loc, inputs, outputType,
-                                        converter);
-}
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
     const TypeConverter *converter) {
@@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), outputType, inputOperands);
       break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
     }
     if (newMaterialization) {
       assert(newMaterialization.getType() == outputType &&
@@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
 
   InFlightDiagnostic diag = op->emitError()
                             << "failed to legalize unresolved materialization "
-                               "from "
-                            << inputOperands.getTypes() << " to " << outputType
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
                             << " that remained live after conversion";
   if (Operation *liveUser = findLiveUser(op->getUsers())) {
     diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index b35cda8e724f6..8254be68912c8 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,9 +2,8 @@
 
 
 func.func @test_invalid_arg_materialization(
-  // expected-error at below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+  // expected-error at below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
   %arg0: i16) {
-  // expected-note at below {{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }
 
@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
 // Make sure argument type changes aren't implicitly forwarded.
 func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
-  // expected-error at below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+  // expected-error at below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
   ^bb0(%arg0: f32):
-    // expected-note at below {{see existing live user here}}
     "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()



More information about the Mlir-commits mailing list