[llvm-branch-commits] [mlir] [mlir][Transforms] Detect mapping overwrites during block signature conversion (PR #121646)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jan 4 06:07:04 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/121646

>From bf57b8d0a3da1c9d383374399a36f766df3f255e Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 4 Jan 2025 13:53:38 +0100
Subject: [PATCH] [mlir][Transforms] Detect mapping overwrites during block
 signature conversion

Add extra assertions to make sure that a value in the conversion value mapping is not overwritten during `applySignatureConversion`.
---
 mlir/lib/Transforms/Utils/DialectConversion.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4904d3ce3f8635..94e61a255dd3be 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -176,6 +176,8 @@ struct ConversionValueMapping {
   template <typename OldVal, typename NewVal>
   std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
   map(OldVal &&oldVal, NewVal &&newVal) {
+    assert(!mapping.contains(oldVal) &&
+           "attempting to overwrite existing mapping");
     LLVM_DEBUG({
       ValueVector next(newVal);
       while (true) {
@@ -1412,6 +1414,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   for (unsigned i = 0; i != origArgCount; ++i) {
     BlockArgument origArg = block->getArgument(i);
     Type origArgType = origArg.getType();
+    ValueVector currentMapping = mapping.lookupOrDefault(origArg);
 
     std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
         signatureConversion.getInputMapping(i);
@@ -1421,7 +1424,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       buildUnresolvedMaterialization(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
+          /*valuesToMap=*/currentMapping, /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
@@ -1432,7 +1435,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, repl);
+      mapping.map(currentMapping, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1441,7 +1444,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
     ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
-    mapping.map(origArg, std::move(replArgVals));
+    mapping.map(currentMapping, std::move(replArgVals));
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1757,6 +1760,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                              << "'(in region of '" << parentOp->getName()
                              << "'(" << from.getOwner()->getParentOp() << ")\n";
   });
+  llvm::errs() << "replaceUsesOfBlockArgument: " << from.getOwner() << "\n";
+
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);



More information about the llvm-branch-commits mailing list