[Mlir-commits] [mlir] [mlir][bufferization] Fix op dominance bug in rewrite pattern (PR #74159)

Matthias Springer llvmlistbot at llvm.org
Fri Dec 1 15:54:58 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/74159

Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used to generate invalid IR (an op dominance error). We never noticed this bug in existing test cases because other patterns and/or foldings were applied afterwards, and those rewrites "fixed up" the IR again. This change also adds comments to the implementation and simplifies the code a bit.

Apart from the fixed dominance error, this change is NFC.


>From 1d8c0f009cc41dd785addcf3262107e55974ac9d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 2 Dec 2023 00:53:23 +0100
Subject: [PATCH] [mlir][bufferization] Fix op dominance bug in rewrite pattern

Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used to generate invalid IR (an op dominance error). We never noticed this bug in existing test cases because other patterns and/or foldings were applied afterwards, and those rewrites "fixed up" the IR again. This change also adds comments to the implementation and simplifies the code a bit.
---
 .../BufferDeallocationSimplification.cpp      | 55 +++++++++++--------
 1 file changed, 31 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 7bbdeab3ea1a870..42653517249d664 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -314,44 +314,51 @@ struct SplitDeallocWhenNotAliasingAnyOther
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
+    Location loc = deallocOp.getLoc();
     if (deallocOp.getMemrefs().size() <= 1)
       return failure();
 
-    SmallVector<Value> newMemrefs, newConditions, replacements;
-    DenseSet<Operation *> exceptedUsers;
-    replacements = deallocOp.getUpdatedConditions();
+    SmallVector<Value> remainingMemrefs, remainingConditions;
+    SmallVector<SmallVector<Value>> updatedConditions;
     for (auto [memref, cond] :
          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Check if `memref` can split off into a separate bufferization.dealloc.
       if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
                                    memref, true)) {
-        newMemrefs.push_back(memref);
-        newConditions.push_back(cond);
+        // `memref` aliases with other memrefs, do not split off.
+        remainingMemrefs.push_back(memref);
+        remainingConditions.push_back(cond);
         continue;
       }
 
-      auto newDeallocOp = rewriter.create<DeallocOp>(
-          deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
-      replacements = SmallVector<Value>(llvm::map_range(
-          llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
-          [&](auto replAndNew) -> Value {
-            auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
-                                                      std::get<0>(replAndNew),
-                                                      std::get<1>(replAndNew));
-            exceptedUsers.insert(orOp);
-            return orOp.getResult();
-          }));
+      // Create new bufferization.dealloc op for `memref`.
+      auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
+                                                     deallocOp.getRetained());
+      updatedConditions.push_back(
+          llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
     }
 
-    if (newMemrefs.size() == deallocOp.getMemrefs().size())
+    // Fail if no memref was split off.
+    if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
       return failure();
 
-    rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
-                               [&](OpOperand &operand) {
-                                 return !exceptedUsers.contains(
-                                     operand.getOwner());
-                               });
-    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
-                                  rewriter);
+    // Create bufferization.dealloc op for all remaining memrefs.
+    auto newDeallocOp = rewriter.create<DeallocOp>(
+        loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
+
+    // Bit-or all conditions.
+    SmallVector<Value> replacements =
+        llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
+    for (auto additionalConditions : updatedConditions) {
+      assert(replacements.size() == additionalConditions.size() &&
+             "expected same number of updated conditions");
+      for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
+        replacements[i] = rewriter.create<arith::OrIOp>(
+            loc, replacements[i], additionalConditions[i]);
+      }
+    }
+    rewriter.replaceOp(deallocOp, replacements);
+    return success();
   }
 
 private:



More information about the Mlir-commits mailing list