[Mlir-commits] [mlir] 3dae97c - [mlir][bufferization] Fix op dominance bug in rewrite pattern (#74159)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 4 18:46:35 PST 2023


Author: Matthias Springer
Date: 2023-12-05T11:46:30+09:00
New Revision: 3dae97cc011ca097bd457bbfa5855da86290f631

URL: https://github.com/llvm/llvm-project/commit/3dae97cc011ca097bd457bbfa5855da86290f631
DIFF: https://github.com/llvm/llvm-project/commit/3dae97cc011ca097bd457bbfa5855da86290f631.diff

LOG: [mlir][bufferization] Fix op dominance bug in rewrite pattern (#74159)

Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used
to generate invalid IR (an op dominance error): it created `arith.ori`
ops in front of the original `bufferization.dealloc` op, and those ops
used that op's results before they were defined. We never noticed this
bug in the existing test cases because other patterns and/or foldings
were applied afterwards, and those rewrites "fixed up" the IR again.
(The bug is visible when running `mlir-opt -debug`.) This change also
adds comments to the implementation and simplifies the code a bit.

Apart from the fixed dominance error, this change is NFC (no functional
change). Without this change, the buffer deallocation tests fail when
run with #74270.

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 7bbdeab3ea1a8..42653517249d6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -314,44 +314,51 @@ struct SplitDeallocWhenNotAliasingAnyOther
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
+    Location loc = deallocOp.getLoc();
     if (deallocOp.getMemrefs().size() <= 1)
       return failure();
 
-    SmallVector<Value> newMemrefs, newConditions, replacements;
-    DenseSet<Operation *> exceptedUsers;
-    replacements = deallocOp.getUpdatedConditions();
+    SmallVector<Value> remainingMemrefs, remainingConditions;
+    SmallVector<SmallVector<Value>> updatedConditions;
     for (auto [memref, cond] :
          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Check if `memref` can split off into a separate bufferization.dealloc.
       if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
                                    memref, true)) {
-        newMemrefs.push_back(memref);
-        newConditions.push_back(cond);
+        // `memref` alias with other memrefs, do not split off.
+        remainingMemrefs.push_back(memref);
+        remainingConditions.push_back(cond);
         continue;
       }
 
-      auto newDeallocOp = rewriter.create<DeallocOp>(
-          deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
-      replacements = SmallVector<Value>(llvm::map_range(
-          llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
-          [&](auto replAndNew) -> Value {
-            auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
-                                                      std::get<0>(replAndNew),
-                                                      std::get<1>(replAndNew));
-            exceptedUsers.insert(orOp);
-            return orOp.getResult();
-          }));
+      // Create new bufferization.dealloc op for `memref`.
+      auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
+                                                     deallocOp.getRetained());
+      updatedConditions.push_back(
+          llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
     }
 
-    if (newMemrefs.size() == deallocOp.getMemrefs().size())
+    // Fail if no memref was split off.
+    if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
       return failure();
 
-    rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
-                               [&](OpOperand &operand) {
-                                 return !exceptedUsers.contains(
-                                     operand.getOwner());
-                               });
-    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
-                                  rewriter);
+    // Create bufferization.dealloc op for all remaining memrefs.
+    auto newDeallocOp = rewriter.create<DeallocOp>(
+        loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
+
+    // Bit-or all conditions.
+    SmallVector<Value> replacements =
+        llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
+    for (auto additionalConditions : updatedConditions) {
+      assert(replacements.size() == additionalConditions.size() &&
+             "expected same number of updated conditions");
+      for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
+        replacements[i] = rewriter.create<arith::OrIOp>(
+            loc, replacements[i], additionalConditions[i]);
+      }
+    }
+    rewriter.replaceOp(deallocOp, replacements);
+    return success();
   }
 
 private:


        


More information about the Mlir-commits mailing list