[Mlir-commits] [mlir] f2d500c - [mlir][Transforms] Dialect conversion: Fix bug in `UnresolvedMaterializationRewrite` rollback (#105949)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 28 18:12:30 PST 2024


Author: Matthias Springer
Date: 2024-11-29T11:12:27+09:00
New Revision: f2d500c61701fc50f5c0c2cd9660a93e15ecc9b9

URL: https://github.com/llvm/llvm-project/commit/f2d500c61701fc50f5c0c2cd9660a93e15ecc9b9
DIFF: https://github.com/llvm/llvm-project/commit/f2d500c61701fc50f5c0c2cd9660a93e15ecc9b9.diff

LOG: [mlir][Transforms] Dialect conversion: Fix bug in `UnresolvedMaterializationRewrite` rollback (#105949)

When an unresolved materialization (an `unrealized_conversion_cast` op) is
rolled back, its entry in the conversion value mapping should be rolled back
as well, regardless of whether it is a source, target, or argument
materialization. Otherwise, the `mapping` accumulates pointers to erased IR.
This is harmless in most cases, but can cause issues when a new operation is
allocated at the same memory location and the stale pointer is "reused".

It is not possible to write a test case for this because I cannot
trigger the pointer reuse programmatically.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 710c976281dc3d..1424c4974f2d43 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -676,7 +676,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                                    UnrealizedConversionCastOp op,
                                    const TypeConverter *converter,
-                                   MaterializationKind kind, Type originalType);
+                                   MaterializationKind kind, Type originalType,
+                                   Value mappedValue);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -710,6 +711,10 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// The original type of the SSA value. Only used for target
   /// materializations.
   Type originalType;
+
+  /// The value in the conversion value mapping that is being replaced by the
+  /// results of this unresolved materialization.
+  Value mappedValue;
 };
 } // namespace
 
@@ -814,10 +819,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
+  ///
+  /// If `valueToMap` is set to a non-null Value, then that value is mapped to
+  /// the result of the unresolved materialization in the conversion value
+  /// mapping.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
                                        OpBuilder::InsertPoint ip, Location loc,
-                                       ValueRange inputs, Type outputType,
-                                       Type originalType,
+                                       Value valueToMap, ValueRange inputs,
+                                       Type outputType, Type originalType,
                                        const TypeConverter *converter);
 
   /// Build an N:1 materialization for the given original value that was
@@ -1068,19 +1077,19 @@ void CreateOperationRewrite::rollback() {
 
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind, Type originalType)
+    const TypeConverter *converter, MaterializationKind kind, Type originalType,
+    Value mappedValue)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind), originalType(originalType) {
+      converterAndKind(converter, kind), originalType(originalType),
+      mappedValue(mappedValue) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   rewriterImpl.unresolvedMaterializations[op] = this;
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (getMaterializationKind() == MaterializationKind::Target) {
-    for (Value input : op->getOperands())
-      rewriterImpl.mapping.erase(input);
-  }
+  if (mappedValue)
+    rewriterImpl.mapping.erase(mappedValue);
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
   op->erase();
 }
@@ -1176,10 +1185,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       // source materialization was created yet.
       Value castValue = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc,
-          /*inputs=*/newOperand, /*outputType=*/desiredType,
-          /*originalType=*/origType, currentTypeConverter);
-      mapping.map(newOperand, castValue);
+          operandLoc, /*valueToMap=*/newOperand, /*inputs=*/newOperand,
+          /*outputType=*/desiredType, /*originalType=*/origType,
+          currentTypeConverter);
       newOperand = castValue;
     }
     remapped.push_back(newOperand);
@@ -1293,12 +1301,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (!inputMap) {
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
-      Value repl = buildUnresolvedMaterialization(
+      buildUnresolvedMaterialization(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*inputs=*/ValueRange(),
+          /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
-      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1342,14 +1349,17 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, Type outputType, Type originalType,
+    Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
     const TypeConverter *converter) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
 
   // Avoid materializing an unnecessary cast.
-  if (inputs.size() == 1 && inputs.front().getType() == outputType)
+  if (inputs.size() == 1 && inputs.front().getType() == outputType) {
+    if (valueToMap)
+      mapping.map(valueToMap, inputs.front());
     return inputs.front();
+  }
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
@@ -1357,8 +1367,10 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+  if (valueToMap)
+    mapping.map(valueToMap, convertOp.getResult(0));
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  originalType);
+                                                  originalType, valueToMap);
   return convertOp.getResult(0);
 }
 
@@ -1367,11 +1379,10 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
     Value originalValue, const TypeConverter *converter) {
   // Insert argument materialization back to the original type.
   Type originalType = originalValue.getType();
-  Value argMat =
-      buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
-                                     /*inputs=*/replacements, originalType,
-                                     /*originalType=*/Type(), converter);
-  mapping.map(originalValue, argMat);
+  Value argMat = buildUnresolvedMaterialization(
+      MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
+      /*inputs=*/replacements, originalType, /*originalType=*/Type(),
+      converter);
 
   // Insert target materialization to the legalized type.
   Type legalOutputType;
@@ -1387,11 +1398,11 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
     legalOutputType = replacements[0].getType();
   }
   if (legalOutputType && legalOutputType != originalType) {
-    Value targetMat = buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(argMat), loc,
-        /*inputs=*/argMat, /*outputType=*/legalOutputType,
-        /*originalType=*/originalType, converter);
-    mapping.map(argMat, targetMat);
+    buildUnresolvedMaterialization(MaterializationKind::Target,
+                                   computeInsertPoint(argMat), loc,
+                                   /*valueToMap=*/argMat, /*inputs=*/argMat,
+                                   /*outputType=*/legalOutputType,
+                                   /*originalType=*/originalType, converter);
   }
 }
 
@@ -1425,9 +1436,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   }
   Value castValue = buildUnresolvedMaterialization(
       MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*inputs=*/repl, /*outputType=*/value.getType(),
+      /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
       /*originalType=*/Type(), converter);
-  mapping.map(value, castValue);
   return castValue;
 }
 
@@ -1480,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       // Materialize a replacement value "out of thin air".
       Value sourceMat = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
-          result.getLoc(), /*inputs=*/ValueRange(),
+          result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
       repl.push_back(sourceMat);


        


More information about the Mlir-commits mailing list