[Mlir-commits] [mlir] [mlir][Transforms] Fix use-after-free when accessing replaced block args (PR #83646)

Matthias Springer llvmlistbot at llvm.org
Fri Mar 1 18:38:12 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83646

This commit fixes a bug in the dialect conversion. Currently, when a block is replaced via a signature conversion, the block is erased during the "commit" phase. This is problematic because the block arguments may still be referenced by internal data structures of the dialect conversion (`mapping`). Blocks should be treated the same as ops: they should be erased during the "cleanup" phase.

Note: The test case fails without this fix when running with ASAN, but may pass when running without ASAN.

>From d3248a821240fd1539cb332d55377864323a9268 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sat, 2 Mar 2024 02:37:05 +0000
Subject: [PATCH] [mlir][Transforms] Fix use-after-free when accessing replaced
 block args

This commit fixes a bug in the dialect conversion. Currently, when a block is replaced via a signature conversion, the block is erased during the "commit" phase. This is problematic because the block arguments may still be referenced by internal data structures of the dialect conversion (`mapping`). Blocks should be treated the same as ops: they should be erased during the "cleanup" phase.

Note: The test case fails without this fix when running with ASAN, but may pass when running without ASAN.
---
 .../Transforms/Utils/DialectConversion.cpp    | 30 ++++++++++++-------
 mlir/test/Transforms/test-legalizer.mlir      | 12 ++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         |  4 +--
 3 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 26899301eb742e..7846f1ab56811a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -207,13 +207,13 @@ class IRRewrite {
   /// Roll back the rewrite. Operations may be erased during rollback.
   virtual void rollback() = 0;
 
-  /// Commit the rewrite. Operations may be unlinked from their blocks during
-  /// the commit phase, but they must not be erased yet. This is because
-  /// internal dialect conversion state (such as `mapping`) may still be using
-  /// them. Operations must be erased during cleanup.
+  /// Commit the rewrite. Operations/blocks may be unlinked during the commit
+  /// phase, but they must not be erased yet. This is because internal dialect
+  /// conversion state (such as `mapping`) may still be using them. Operations/
+  /// blocks must be erased during cleanup.
   virtual void commit() {}
 
-  /// Cleanup operations. Cleanup is called after commit.
+  /// Cleanup operations/blocks. Cleanup is called after commit.
   virtual void cleanup() {}
 
   Kind getKind() const { return kind; }
@@ -282,9 +282,9 @@ class CreateBlockRewrite : public BlockRewrite {
 };
 
 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
-/// blocks are immediately unlinked, but only erased when the rewrite is
-/// committed. This makes it easier to rollback a block erasure: the block is
-/// simply inserted into its original location.
+/// blocks are immediately unlinked, but only erased during cleanup. This makes
+/// it easier to rollback a block erasure: the block is simply inserted into its
+/// original location.
 class EraseBlockRewrite : public BlockRewrite {
 public:
   EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
@@ -297,7 +297,8 @@ class EraseBlockRewrite : public BlockRewrite {
   }
 
   ~EraseBlockRewrite() override {
-    assert(!block && "rewrite was neither rolled back nor committed");
+    assert(!block &&
+           "rewrite was neither rolled back nor committed/cleaned up");
   }
 
   void rollback() override {
@@ -312,7 +313,7 @@ class EraseBlockRewrite : public BlockRewrite {
     block = nullptr;
   }
 
-  void commit() override {
+  void cleanup() override {
     // Erase the block.
     assert(block && "expected block");
     assert(block->empty() && "expected empty block");
@@ -440,6 +441,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
 
   void commit() override;
 
+  void cleanup() override;
+
   void rollback() override;
 
 private:
@@ -986,7 +989,9 @@ void BlockTypeConversionRewrite::commit() {
           rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
     }
   }
+}
 
+void BlockTypeConversionRewrite::cleanup() {
   assert(origBlock->empty() && "expected empty block");
   origBlock->dropAllDefinedValueUses();
   delete origBlock;
@@ -1484,6 +1489,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   assert(!wasOpReplaced(block->getParentOp()) &&
          "attempting to insert into a region within a replaced/erased op");
+  LLVM_DEBUG({
+    logger.startLine() << "** Insert Block into : '"
+                       << block->getParentOp()->getName() << "'("
+                       << block->getParentOp() << ")\n";
+  });
 
   if (!previous) {
     // This is a newly created block.
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 62d776cd7573ee..ccdc9fe78ea0d3 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -346,3 +346,15 @@ func.func @test_properties_rollback() {
       {modify_inplace}
   "test.return"() : () -> ()
 }
+
+// -----
+
+//      CHECK: func.func @use_of_replaced_bbarg(
+// CHECK-SAME:     %[[arg0:.*]]: f64)
+//      CHECK:   "test.valid"(%[[arg0]])
+func.func @use_of_replaced_bbarg(%arg0: i64) {
+  %0 = "test.op_with_region_fold"(%arg0) ({
+    "foo.op_with_region_terminator"() : () -> ()
+  }) : (i64) -> (i64)
+  "test.invalid"(%0) : (i64) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 91ce0af9cd7fd5..dfd2f21a5ea249 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1270,8 +1270,8 @@ def TestOpWithRegionFoldNoMemoryEffect : TEST_Op<
 
 // Op for testing folding of outer op with inner ops.
 def TestOpWithRegionFold : TEST_Op<"op_with_region_fold"> {
-  let arguments = (ins I32:$operand);
-  let results = (outs I32);
+  let arguments = (ins AnyType:$operand);
+  let results = (outs AnyType);
   let regions = (region SizedRegion<1>:$region);
   let hasFolder = 1;
 }



More information about the Mlir-commits mailing list