[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (PR #96329)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jun 21 09:47:58 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/96329

This commit moves the materialization logic for converted block arguments from `legalizeConvertedArgumentTypes` to `legalizeUnresolvedMaterializations`.

Before this change:
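- Materializations for converted block arguments that were still live were created eagerly in `legalizeConvertedArgumentTypes`, which called `BlockTypeConversionRewrite::materializeLiveConversions`; that helper invoked the type converter's `materializeSourceConversion` directly.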

After this change:
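- `legalizeConvertedArgumentTypes` now only creates a "placeholder" `unrealized_conversion_cast` (an unresolved source materialization) for each converted block argument that is still live.
- This placeholder `unrealized_conversion_cast` is later replaced with a real materialization (built using the type converter) in `legalizeUnresolvedMaterializations`. For this to work, `finalize` now runs `legalizeConvertedArgumentTypes` before `legalizeUnresolvedMaterializations`.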
- All argument and target materializations now take place in the same location (`legalizeUnresolvedMaterializations`).

This commit brings us closer to creating all source/target/argument materializations in one central step, which can then be made optional (and delegated to the user) in the future. (One more source materialization step has not been moved yet.)

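This commit also consolidates all `build*UnresolvedMaterialization` functions into a single `buildUnresolvedMaterialization` function, which now takes an `OpBuilder::InsertPoint`. The insertion-point computation is factored out into a new `computeInsertPoint` helper.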



>From d2b0b9ef97c626fc48b0c00ce2ec8e5573599f2b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 20 Jun 2024 17:40:07 +0200
Subject: [PATCH] remove materializeLiveConversions

---
 .../Transforms/Utils/DialectConversion.cpp    | 133 +++++++-----------
 1 file changed, 52 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 07ebd687ee2b3..47e03383304af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
+/// Compute an insertion point where `value` can be used without a dominance
+/// violation: right after its defining op, or at the start of its block.
+static OpBuilder::InsertPoint computeInsertPoint(Value value) {
+  Block *insertBlock = value.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(value))
+    insertPt = ++inputRes.getOwner()->getIterator();
+  return OpBuilder::InsertPoint(insertBlock, insertPt);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
@@ -445,11 +455,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  /// Materialize any necessary conversions for converted arguments that have
-  /// live users, using the provided `findLiveUser` to search for a user that
-  /// survives the conversion process.
-  LogicalResult
-  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+  Block *getOrigBlock() const { return origBlock; }
+
+  const TypeConverter *getConverter() const { return converter; }
 
   void commit(RewriterBase &rewriter) override;
 
@@ -841,14 +849,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
-                                       Block *insertBlock,
-                                       Block::iterator insertPt, Location loc,
+                                       OpBuilder::InsertPoint ip, Location loc,
                                        ValueRange inputs, Type outputType,
                                        Type origOutputType,
                                        const TypeConverter *converter);
-  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
-                                             Type outputType,
-                                             const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -981,49 +985,6 @@ void BlockTypeConversionRewrite::rollback() {
   block->replaceAllUsesWith(origBlock);
 }
 
-LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
-    function_ref<Operation *(Value)> findLiveUser) {
-  // Process the remapping for each of the original arguments.
-  for (auto it : llvm::enumerate(origBlock->getArguments())) {
-    BlockArgument origArg = it.value();
-    // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
-    OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
-    builder.setInsertionPointToStart(block);
-
-    // If the type of this argument changed and the argument is still live, we
-    // need to materialize a conversion.
-    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-      continue;
-    Operation *liveUser = findLiveUser(origArg);
-    if (!liveUser)
-      continue;
-
-    Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
-    assert(replacementValue && "replacement value not found");
-    Value newArg;
-    if (converter) {
-      builder.setInsertionPointAfterValue(replacementValue);
-      newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(), replacementValue);
-      assert((!newArg || newArg.getType() == origArg.getType()) &&
-             "materialization hook did not provide a value of the expected "
-             "type");
-    }
-    if (!newArg) {
-      InFlightDiagnostic diag =
-          emitError(origArg.getLoc())
-          << "failed to materialize conversion for block argument #"
-          << it.index() << " that remained live after conversion, type was "
-          << origArg.getType();
-      diag.attachNote(liveUser->getLoc())
-          << "see existing live user here: " << *liveUser;
-      return failure();
-    }
-    rewriterImpl.mapping.map(origArg, newArg);
-  }
-  return success();
-}
-
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
@@ -1196,8 +1157,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type newOperandType = newOperand.getType();
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-      Value castValue = buildUnresolvedTargetMaterialization(
-          operandLoc, newOperand, desiredType, currentTypeConverter);
+      Value castValue = buildUnresolvedMaterialization(
+          MaterializationKind::Target, computeInsertPoint(newOperand),
+          operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
+          /*origArgType=*/{}, currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1325,8 +1288,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
       Value repl = buildUnresolvedMaterialization(
-          MaterializationKind::Source, newBlock, newBlock->begin(),
-          origArg.getLoc(), /*inputs=*/ValueRange(),
+          MaterializationKind::Source,
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*origArgType=*/{}, converter);
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1351,8 +1315,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
     Value repl = buildUnresolvedMaterialization(
-        MaterializationKind::Argument, newBlock, newBlock->begin(),
-        origArg.getLoc(), /*inputs=*/replArgs,
+        MaterializationKind::Argument,
+        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+        /*inputs=*/replArgs,
         /*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
     mapping.map(origArg, repl);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1374,8 +1339,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// Build an unresolved materialization operation given an output type and set
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
-    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
-    Location loc, ValueRange inputs, Type outputType, Type origArgType,
+    MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+    ValueRange inputs, Type outputType, Type origArgType,
     const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1383,25 +1348,13 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
+  OpBuilder builder(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
                                                   origArgType);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
-    Location loc, Value input, Type outputType,
-    const TypeConverter *converter) {
-  Block *insertBlock = input.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(input))
-    insertPt = ++inputRes.getOwner()->getIterator();
-
-  return buildUnresolvedMaterialization(
-      MaterializationKind::Target, insertBlock, insertPt, loc, input,
-      outputType, /*origArgType=*/{}, converter);
-}
 
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
@@ -2515,9 +2468,9 @@ LogicalResult
 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)) ||
-      failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
+  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
+      failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+                                                inverseMapping)))
     return failure();
 
   // Process requested operation replacements.
@@ -2573,10 +2526,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
        ++i) {
     auto &rewrite = rewriterImpl.rewrites[i];
     if (auto *blockTypeConversionRewrite =
-            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
-      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
-              findLiveUser)))
-        return failure();
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
+      // Process the remapping for each of the original arguments.
+      for (Value origArg :
+           blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
+        // If the type of this argument changed and the argument is still live,
+        // we need to materialize a conversion.
+        if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+          continue;
+        Operation *liveUser = findLiveUser(origArg);
+        if (!liveUser)
+          continue;
+
+        Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+        assert(replacementValue && "replacement value not found");
+        Value repl = rewriterImpl.buildUnresolvedMaterialization(
+            MaterializationKind::Source, computeInsertPoint(replacementValue),
+            origArg.getLoc(), /*inputs=*/replacementValue,
+            /*outputType=*/origArg.getType(), /*origArgType=*/{},
+            blockTypeConversionRewrite->getConverter());
+        rewriterImpl.mapping.map(origArg, repl);
+      }
+    }
   }
   return success();
 }



More information about the llvm-branch-commits mailing list