[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (PR #98805)

Matthias Springer llvmlistbot at llvm.org
Sat Jul 20 03:18:43 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/98805

>From 2c5cc19206e737c8fab549f8b5daa78654c257d1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 20 Jun 2024 17:40:07 +0200
Subject: [PATCH] [mlir][Transforms][NFC] Dialect Conversion: Move argument
 materialization logic

This commit moves the argument materialization logic from `legalizeConvertedArgumentTypes` to `legalizeUnresolvedMaterializations`.

Before this change:
- Argument materializations were created in `legalizeConvertedArgumentTypes` (which used to call `materializeLiveConversions`).

After this change:
- `legalizeConvertedArgumentTypes` creates a "placeholder" `unrealized_conversion_cast`.
- The placeholder `unrealized_conversion_cast` is replaced with an argument materialization (using the type converter) in `legalizeUnresolvedMaterializations`.
- All argument and target materializations now take place in the same location (`legalizeUnresolvedMaterializations`).

This commit brings us closer to creating all source/target/argument materializations in one central step, which can then be made optional (and delegated to the user) in the future. (There is one more source materialization step that has not been moved yet.)

This commit also consolidates all `build*UnresolvedMaterialization` functions into a single `buildUnresolvedMaterialization` function.

This is a re-upload of #96329.
---
 .../Transforms/Utils/DialectConversion.cpp    | 138 +++++++-----------
 1 file changed, 54 insertions(+), 84 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0b552a7e1ca3b..a2e7570380351 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
+/// Helper function that computes an insertion point where the given value is
+/// defined and can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(Value value) {
+  Block *insertBlock = value.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(value))
+    insertPt = ++inputRes.getOwner()->getIterator();
+  return OpBuilder::InsertPoint(insertBlock, insertPt);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
@@ -445,11 +455,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  /// Materialize any necessary conversions for converted arguments that have
-  /// live users, using the provided `findLiveUser` to search for a user that
-  /// survives the conversion process.
-  LogicalResult
-  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+  Block *getOrigBlock() const { return origBlock; }
+
+  const TypeConverter *getConverter() const { return converter; }
 
   void commit(RewriterBase &rewriter) override;
 
@@ -830,15 +838,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
-                                       Block *insertBlock,
-                                       Block::iterator insertPt, Location loc,
+                                       OpBuilder::InsertPoint ip, Location loc,
                                        ValueRange inputs, Type outputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
-                                             Type outputType,
-                                             const TypeConverter *converter);
-
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -970,49 +973,6 @@ void BlockTypeConversionRewrite::rollback() {
   block->replaceAllUsesWith(origBlock);
 }
 
-LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
-    function_ref<Operation *(Value)> findLiveUser) {
-  // Process the remapping for each of the original arguments.
-  for (auto it : llvm::enumerate(origBlock->getArguments())) {
-    BlockArgument origArg = it.value();
-    // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
-    OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
-    builder.setInsertionPointToStart(block);
-
-    // If the type of this argument changed and the argument is still live, we
-    // need to materialize a conversion.
-    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-      continue;
-    Operation *liveUser = findLiveUser(origArg);
-    if (!liveUser)
-      continue;
-
-    Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
-    assert(replacementValue && "replacement value not found");
-    Value newArg;
-    if (converter) {
-      builder.setInsertionPointAfterValue(replacementValue);
-      newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(), replacementValue);
-      assert((!newArg || newArg.getType() == origArg.getType()) &&
-             "materialization hook did not provide a value of the expected "
-             "type");
-    }
-    if (!newArg) {
-      InFlightDiagnostic diag =
-          emitError(origArg.getLoc())
-          << "failed to materialize conversion for block argument #"
-          << it.index() << " that remained live after conversion, type was "
-          << origArg.getType();
-      diag.attachNote(liveUser->getLoc())
-          << "see existing live user here: " << *liveUser;
-      return failure();
-    }
-    rewriterImpl.mapping.map(origArg, newArg);
-  }
-  return success();
-}
-
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
@@ -1185,8 +1145,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type newOperandType = newOperand.getType();
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-      Value castValue = buildUnresolvedTargetMaterialization(
-          operandLoc, newOperand, desiredType, currentTypeConverter);
+      Value castValue = buildUnresolvedMaterialization(
+          MaterializationKind::Target, computeInsertPoint(newOperand),
+          operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
+          currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1299,8 +1261,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
       Value repl = buildUnresolvedMaterialization(
-          MaterializationKind::Source, newBlock, newBlock->begin(),
-          origArg.getLoc(), /*inputs=*/ValueRange(),
+          MaterializationKind::Source,
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, converter);
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1324,8 +1287,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
     Value argMat = buildUnresolvedMaterialization(
-        MaterializationKind::Argument, newBlock, newBlock->begin(),
-        origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
+        MaterializationKind::Argument,
+        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+        /*inputs=*/replArgs, origArgType, converter);
     mapping.map(origArg, argMat);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
 
@@ -1339,7 +1303,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (converter)
       legalOutputType = converter->convertType(origArgType);
     if (legalOutputType && legalOutputType != origArgType) {
-      Value targetMat = buildUnresolvedTargetMaterialization(
+      Value targetMat = buildUnresolvedMaterialization(
+          MaterializationKind::Target, computeInsertPoint(argMat),
           origArg.getLoc(), argMat, legalOutputType, converter);
       mapping.map(argMat, targetMat);
     }
@@ -1362,33 +1327,20 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// Build an unresolved materialization operation given an output type and set
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
-    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
-    Location loc, ValueRange inputs, Type outputType,
-    const TypeConverter *converter) {
+    MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+    ValueRange inputs, Type outputType, const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
     return inputs.front();
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
+  OpBuilder builder(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
-    Location loc, Value input, Type outputType,
-    const TypeConverter *converter) {
-  Block *insertBlock = input.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(input))
-    insertPt = ++inputRes.getOwner()->getIterator();
-
-  return buildUnresolvedMaterialization(MaterializationKind::Target,
-                                        insertBlock, insertPt, loc, input,
-                                        outputType, converter);
-}
 
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
@@ -2502,9 +2454,9 @@ LogicalResult
 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)) ||
-      failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
+  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
+      failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+                                                inverseMapping)))
     return failure();
 
   // Process requested operation replacements.
@@ -2560,10 +2512,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
        ++i) {
     auto &rewrite = rewriterImpl.rewrites[i];
     if (auto *blockTypeConversionRewrite =
-            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
-      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
-              findLiveUser)))
-        return failure();
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
+      // Process the remapping for each of the original arguments.
+      for (Value origArg :
+           blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
+        // If the type of this argument changed and the argument is still live,
+        // we need to materialize a conversion.
+        if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+          continue;
+        Operation *liveUser = findLiveUser(origArg);
+        if (!liveUser)
+          continue;
+
+        Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+        assert(replacementValue && "replacement value not found");
+        Value repl = rewriterImpl.buildUnresolvedMaterialization(
+            MaterializationKind::Source, computeInsertPoint(replacementValue),
+            origArg.getLoc(), /*inputs=*/replacementValue,
+            /*outputType=*/origArg.getType(),
+            blockTypeConversionRewrite->getConverter());
+        rewriterImpl.mapping.map(origArg, repl);
+      }
+    }
   }
   return success();
 }



More information about the Mlir-commits mailing list