[Mlir-commits] [mlir] 3dae97c - [mlir][bufferization] Fix op dominance bug in rewrite pattern (#74159)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 4 18:46:35 PST 2023
Author: Matthias Springer
Date: 2023-12-05T11:46:30+09:00
New Revision: 3dae97cc011ca097bd457bbfa5855da86290f631
URL: https://github.com/llvm/llvm-project/commit/3dae97cc011ca097bd457bbfa5855da86290f631
DIFF: https://github.com/llvm/llvm-project/commit/3dae97cc011ca097bd457bbfa5855da86290f631.diff
LOG: [mlir][bufferization] Fix op dominance bug in rewrite pattern (#74159)
Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used
to generate invalid IR (op dominance error): the `arith.ori` ops it
created were inserted before the original `bufferization.dealloc` op but
used that op's results. We never noticed this bug in existing test cases
because other patterns and/or foldings were applied afterwards and those
rewrites "fixed up" the IR again. (The bug is visible when running
`mlir-opt -debug`.) Also add more comments to the implementation and
simplify the code a bit.
Apart from the fixed dominance error, this change is NFC (no functional
change). Without this change, buffer deallocation tests fail when running
with #74270.
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 7bbdeab3ea1a8..42653517249d6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -314,44 +314,51 @@ struct SplitDeallocWhenNotAliasingAnyOther
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
+ Location loc = deallocOp.getLoc();
if (deallocOp.getMemrefs().size() <= 1)
return failure();
- SmallVector<Value> newMemrefs, newConditions, replacements;
- DenseSet<Operation *> exceptedUsers;
- replacements = deallocOp.getUpdatedConditions();
+ SmallVector<Value> remainingMemrefs, remainingConditions;
+ SmallVector<SmallVector<Value>> updatedConditions;
for (auto [memref, cond] :
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+ // Check if `memref` can split off into a separate bufferization.dealloc.
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
memref, true)) {
- newMemrefs.push_back(memref);
- newConditions.push_back(cond);
+ // `memref` alias with other memrefs, do not split off.
+ remainingMemrefs.push_back(memref);
+ remainingConditions.push_back(cond);
continue;
}
- auto newDeallocOp = rewriter.create<DeallocOp>(
- deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
- replacements = SmallVector<Value>(llvm::map_range(
- llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
- [&](auto replAndNew) -> Value {
- auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
- std::get<0>(replAndNew),
- std::get<1>(replAndNew));
- exceptedUsers.insert(orOp);
- return orOp.getResult();
- }));
+ // Create new bufferization.dealloc op for `memref`.
+ auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
+ deallocOp.getRetained());
+ updatedConditions.push_back(
+ llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
}
- if (newMemrefs.size() == deallocOp.getMemrefs().size())
+ // Fail if no memref was split off.
+ if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
return failure();
- rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
- [&](OpOperand &operand) {
- return !exceptedUsers.contains(
- operand.getOwner());
- });
- return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
- rewriter);
+ // Create bufferization.dealloc op for all remaining memrefs.
+ auto newDeallocOp = rewriter.create<DeallocOp>(
+ loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
+
+ // Bit-or all conditions.
+ SmallVector<Value> replacements =
+ llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
+ for (auto additionalConditions : updatedConditions) {
+ assert(replacements.size() == additionalConditions.size() &&
+ "expected same number of updated conditions");
+ for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
+ replacements[i] = rewriter.create<arith::OrIOp>(
+ loc, replacements[i], additionalConditions[i]);
+ }
+ }
+ rewriter.replaceOp(deallocOp, replacements);
+ return success();
}
private:
More information about the Mlir-commits mailing list