[Mlir-commits] [mlir] [mlir][bufferization] Fix op dominance bug in rewrite pattern (PR #74159)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 1 15:55:24 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
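Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used to generate invalid IR (an op dominance error): the `arith.ori` ops that combine the updated conditions were created at the rewriter's insertion point, i.e., before the original `bufferization.dealloc` op, yet they used that op's results. We never noticed this bug in existing test cases because other patterns and/or foldings were applied afterwards, and those rewrites "fixed up" the IR again. This change also adds comments to the implementation and simplifies the code a bit.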
Apart from the fixed dominance error, this change is NFC.
---
Full diff: https://github.com/llvm/llvm-project/pull/74159.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+31-24)
``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 7bbdeab3ea1a870..42653517249d664 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -314,44 +314,51 @@ struct SplitDeallocWhenNotAliasingAnyOther
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
+ Location loc = deallocOp.getLoc();
if (deallocOp.getMemrefs().size() <= 1)
return failure();
- SmallVector<Value> newMemrefs, newConditions, replacements;
- DenseSet<Operation *> exceptedUsers;
- replacements = deallocOp.getUpdatedConditions();
+ SmallVector<Value> remainingMemrefs, remainingConditions;
+ SmallVector<SmallVector<Value>> updatedConditions;
for (auto [memref, cond] :
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+ // Check if `memref` can split off into a separate bufferization.dealloc.
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
memref, true)) {
- newMemrefs.push_back(memref);
- newConditions.push_back(cond);
+ // `memref` alias with other memrefs, do not split off.
+ remainingMemrefs.push_back(memref);
+ remainingConditions.push_back(cond);
continue;
}
- auto newDeallocOp = rewriter.create<DeallocOp>(
- deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
- replacements = SmallVector<Value>(llvm::map_range(
- llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
- [&](auto replAndNew) -> Value {
- auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
- std::get<0>(replAndNew),
- std::get<1>(replAndNew));
- exceptedUsers.insert(orOp);
- return orOp.getResult();
- }));
+ // Create new bufferization.dealloc op for `memref`.
+ auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
+ deallocOp.getRetained());
+ updatedConditions.push_back(
+ llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
}
- if (newMemrefs.size() == deallocOp.getMemrefs().size())
+ // Fail if no memref was split off.
+ if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
return failure();
- rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
- [&](OpOperand &operand) {
- return !exceptedUsers.contains(
- operand.getOwner());
- });
- return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
- rewriter);
+ // Create bufferization.dealloc op for all remaining memrefs.
+ auto newDeallocOp = rewriter.create<DeallocOp>(
+ loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
+
+ // Bit-or all conditions.
+ SmallVector<Value> replacements =
+ llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
+ for (auto additionalConditions : updatedConditions) {
+ assert(replacements.size() == additionalConditions.size() &&
+ "expected same number of updated conditions");
+ for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
+ replacements[i] = rewriter.create<arith::OrIOp>(
+ loc, replacements[i], additionalConditions[i]);
+ }
+ }
+ rewriter.replaceOp(deallocOp, replacements);
+ return success();
}
private:
``````````
</details>
https://github.com/llvm/llvm-project/pull/74159
More information about the Mlir-commits
mailing list