[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: extra signature conversion check (PR #117471)

Matthias Springer llvmlistbot at llvm.org
Sun Nov 24 00:39:32 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/117471

This commit adds an extra assertion to `applySignatureConversion` to prevent incorrect API usage: the same block must not be converted multiple times. Doing so would corrupt the underlying conversion value mapping, because existing mappings would be overwritten. This is similar to op replacements: the same op cannot be replaced multiple times.

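To simplify the check, `BlockTypeConversionRewrite::block` now stores the original block, and the new block is stored in an extra field (`newBlock`). Previously it was the other way around: `block` held the new block and the original block was kept in a separate `origBlock` field.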

This commit is in preparation for adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust.


>From a7073f8cc92008917a9c65c42840894f9df685bc Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 24 Nov 2024 09:33:18 +0100
Subject: [PATCH] [mlir][Transforms] Dialect conversion: extra signature
 conversion checks

---
 .../Transforms/Utils/DialectConversion.cpp    | 35 ++++++++++++++-----
 1 file changed, 26 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5acd095da8e386..710c976281dc3d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -434,23 +434,25 @@ class MoveBlockRewrite : public BlockRewrite {
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
   BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                             Block *block, Block *origBlock)
-      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock) {}
+                             Block *origBlock, Block *newBlock)
+      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
+        newBlock(newBlock) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  Block *getOrigBlock() const { return origBlock; }
+  Block *getOrigBlock() const { return block; }
+
+  Block *getNewBlock() const { return newBlock; }
 
   void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
 private:
-  /// The original block that was requested to have its signature converted.
-  Block *origBlock;
+  /// The new block that was created as part of this signature conversion.
+  Block *newBlock;
 };
 
 /// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -721,6 +723,18 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
+#ifndef NDEBUG
+/// Return "true" if there is a block rewrite that matches the specified
+/// rewrite type and block among the given rewrites.
+template <typename RewriteTy, typename R>
+static bool hasRewrite(R &&rewrites, Block *block) {
+  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
+    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+    return rewriteTy && rewriteTy->getBlock() == block;
+  });
+}
+#endif // NDEBUG
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -966,12 +980,12 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
   // block.
   if (auto *listener =
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
-    for (Operation *op : block->getUsers())
+    for (Operation *op : getNewBlock()->getUsers())
       listener->notifyOperationModified(op);
 }
 
 void BlockTypeConversionRewrite::rollback() {
-  block->replaceAllUsesWith(origBlock);
+  getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
@@ -1223,6 +1237,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     ConversionPatternRewriter &rewriter, Block *block,
     const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
+  // A block cannot be converted multiple times.
+  assert(!hasRewrite<BlockTypeConversionRewrite>(rewrites, block) &&
+         "block was already converted");
   OpBuilder::InsertionGuard g(rewriter);
 
   // If no arguments are being changed or added, there is nothing to do.
@@ -1308,7 +1325,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
+  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)



More information about the Mlir-commits mailing list