[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (PR #96329)

Matthias Springer llvmlistbot at llvm.org
Mon Jun 24 23:44:48 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/96329

>From 50eec6bea675fb47aa36bdb49276632faff4aae7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 20 Jun 2024 17:40:07 +0200
Subject: [PATCH] remove materializeLiveConversions

---
 .../Transforms/Utils/DialectConversion.cpp    | 133 +++++++-----------
 1 file changed, 52 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 07ebd687ee2b3..47e03383304af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
+/// Computes an insertion point at which `value` is available: immediately
+/// after its defining op, else (block argument) at the start of its block.
+static OpBuilder::InsertPoint computeInsertPoint(Value value) {
+  Block *insertBlock = value.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(value))
+    insertPt = ++inputRes.getOwner()->getIterator();
+  return OpBuilder::InsertPoint(insertBlock, insertPt);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
@@ -445,11 +455,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  /// Materialize any necessary conversions for converted arguments that have
-  /// live users, using the provided `findLiveUser` to search for a user that
-  /// survives the conversion process.
-  LogicalResult
-  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+  Block *getOrigBlock() const { return origBlock; }
+
+  const TypeConverter *getConverter() const { return converter; }
 
   void commit(RewriterBase &rewriter) override;
 
@@ -841,14 +849,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
-                                       Block *insertBlock,
-                                       Block::iterator insertPt, Location loc,
+                                       OpBuilder::InsertPoint ip, Location loc,
                                        ValueRange inputs, Type outputType,
                                        Type origOutputType,
                                        const TypeConverter *converter);
-  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
-                                             Type outputType,
-                                             const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -981,49 +985,6 @@ void BlockTypeConversionRewrite::rollback() {
   block->replaceAllUsesWith(origBlock);
 }
 
-LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
-    function_ref<Operation *(Value)> findLiveUser) {
-  // Process the remapping for each of the original arguments.
-  for (auto it : llvm::enumerate(origBlock->getArguments())) {
-    BlockArgument origArg = it.value();
-    // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
-    OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
-    builder.setInsertionPointToStart(block);
-
-    // If the type of this argument changed and the argument is still live, we
-    // need to materialize a conversion.
-    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-      continue;
-    Operation *liveUser = findLiveUser(origArg);
-    if (!liveUser)
-      continue;
-
-    Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
-    assert(replacementValue && "replacement value not found");
-    Value newArg;
-    if (converter) {
-      builder.setInsertionPointAfterValue(replacementValue);
-      newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(), replacementValue);
-      assert((!newArg || newArg.getType() == origArg.getType()) &&
-             "materialization hook did not provide a value of the expected "
-             "type");
-    }
-    if (!newArg) {
-      InFlightDiagnostic diag =
-          emitError(origArg.getLoc())
-          << "failed to materialize conversion for block argument #"
-          << it.index() << " that remained live after conversion, type was "
-          << origArg.getType();
-      diag.attachNote(liveUser->getLoc())
-          << "see existing live user here: " << *liveUser;
-      return failure();
-    }
-    rewriterImpl.mapping.map(origArg, newArg);
-  }
-  return success();
-}
-
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
@@ -1196,8 +1157,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type newOperandType = newOperand.getType();
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-      Value castValue = buildUnresolvedTargetMaterialization(
-          operandLoc, newOperand, desiredType, currentTypeConverter);
+      Value castValue = buildUnresolvedMaterialization(
+          MaterializationKind::Target, computeInsertPoint(newOperand),
+          operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
+          /*origArgType=*/{}, currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1325,8 +1288,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
       Value repl = buildUnresolvedMaterialization(
-          MaterializationKind::Source, newBlock, newBlock->begin(),
-          origArg.getLoc(), /*inputs=*/ValueRange(),
+          MaterializationKind::Source,
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*origArgType=*/{}, converter);
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1351,8 +1315,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
     Value repl = buildUnresolvedMaterialization(
-        MaterializationKind::Argument, newBlock, newBlock->begin(),
-        origArg.getLoc(), /*inputs=*/replArgs,
+        MaterializationKind::Argument,
+        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+        /*inputs=*/replArgs,
         /*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
     mapping.map(origArg, repl);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1374,8 +1339,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// Build an unresolved materialization operation given an output type and set
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
-    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
-    Location loc, ValueRange inputs, Type outputType, Type origArgType,
+    MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+    ValueRange inputs, Type outputType, Type origArgType,
     const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1383,25 +1348,13 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
+  OpBuilder builder(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
                                                   origArgType);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
-    Location loc, Value input, Type outputType,
-    const TypeConverter *converter) {
-  Block *insertBlock = input.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(input))
-    insertPt = ++inputRes.getOwner()->getIterator();
-
-  return buildUnresolvedMaterialization(
-      MaterializationKind::Target, insertBlock, insertPt, loc, input,
-      outputType, /*origArgType=*/{}, converter);
-}
 
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
@@ -2515,9 +2468,9 @@ LogicalResult
 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)) ||
-      failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
+  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
+      failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+                                                inverseMapping)))
     return failure();
 
   // Process requested operation replacements.
@@ -2573,10 +2526,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
        ++i) {
     auto &rewrite = rewriterImpl.rewrites[i];
     if (auto *blockTypeConversionRewrite =
-            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
-      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
-              findLiveUser)))
-        return failure();
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
+      // Process the remapping for each of the original arguments.
+      for (Value origArg :
+           blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
+        // If the type of this argument changed and the argument is still live,
+        // we need to materialize a conversion.
+        if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+          continue;
+        Operation *liveUser = findLiveUser(origArg);
+        if (!liveUser)
+          continue;
+
+        Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+        assert(replacementValue && "replacement value not found");
+        Value repl = rewriterImpl.buildUnresolvedMaterialization(
+            MaterializationKind::Source, computeInsertPoint(replacementValue),
+            origArg.getLoc(), /*inputs=*/replacementValue,
+            /*outputType=*/origArg.getType(), /*origArgType=*/{},
+            blockTypeConversionRewrite->getConverter());
+        rewriterImpl.mapping.map(origArg, repl);
+      }
+    }
   }
   return success();
 }



More information about the Mlir-commits mailing list