[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: extra signature conversion check (PR #117471)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 24 00:40:58 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds an extra assertion to `applySignatureConversion` to prevent incorrect API usage: the same block cannot be converted multiple times. Converting a block a second time would corrupt the underlying conversion value mapping, because the mappings recorded by the first conversion would be overwritten. This is similar to op replacements: the same op cannot be replaced multiple times.

To simplify the check, `BlockTypeConversionRewrite::block` now stores the original block. The new block is stored in an extra field. (It used to be the other way around.)

This commit is in preparation for adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust.


---
Full diff: https://github.com/llvm/llvm-project/pull/117471.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+26-9) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5acd095da8e386..710c976281dc3d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -434,23 +434,25 @@ class MoveBlockRewrite : public BlockRewrite {
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
   BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                             Block *block, Block *origBlock)
-      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock) {}
+                             Block *origBlock, Block *newBlock)
+      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
+        newBlock(newBlock) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  Block *getOrigBlock() const { return origBlock; }
+  Block *getOrigBlock() const { return block; }
+
+  Block *getNewBlock() const { return newBlock; }
 
   void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
 private:
-  /// The original block that was requested to have its signature converted.
-  Block *origBlock;
+  /// The new block that was created as part of this signature conversion.
+  Block *newBlock;
 };
 
 /// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -721,6 +723,18 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
+#ifndef NDEBUG
+/// Return "true" if there is a block rewrite that matches the specified
+/// rewrite type and block among the given rewrites.
+template <typename RewriteTy, typename R>
+static bool hasRewrite(R &&rewrites, Block *block) {
+  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
+    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+    return rewriteTy && rewriteTy->getBlock() == block;
+  });
+}
+#endif // NDEBUG
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -966,12 +980,12 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
   // block.
   if (auto *listener =
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
-    for (Operation *op : block->getUsers())
+    for (Operation *op : getNewBlock()->getUsers())
       listener->notifyOperationModified(op);
 }
 
 void BlockTypeConversionRewrite::rollback() {
-  block->replaceAllUsesWith(origBlock);
+  getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
@@ -1223,6 +1237,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     ConversionPatternRewriter &rewriter, Block *block,
     const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
+  // A block cannot be converted multiple times.
+  assert(!hasRewrite<BlockTypeConversionRewrite>(rewrites, block) &&
+         "block was already converted");
   OpBuilder::InsertionGuard g(rewriter);
 
   // If no arguments are being changed or added, there is nothing to do.
@@ -1308,7 +1325,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
+  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)

``````````

</details>


https://github.com/llvm/llvm-project/pull/117471


More information about the Mlir-commits mailing list