[Mlir-commits] [mlir] 5c7d97b - [mlir][bufferization] Canonicalizer to skip extract_strided_metadata when operand is already a base memref.
Martin Erhart
llvmlistbot at llvm.org
Thu Aug 10 00:11:47 PDT 2023
Author: Martin Erhart
Date: 2023-08-10T07:11:25Z
New Revision: 5c7d97be4acb1ec7dca34555873bbc8f562dfdfa
URL: https://github.com/llvm/llvm-project/commit/5c7d97be4acb1ec7dca34555873bbc8f562dfdfa
DIFF: https://github.com/llvm/llvm-project/commit/5c7d97be4acb1ec7dca34555873bbc8f562dfdfa.diff
LOG: [mlir][bufferization] Canonicalizer to skip extract_strided_metadata when operand is already a base memref.
The `memref.extract_strided_metadata` op will be heavily used by the new buffer deallocation pass to obtain the base memref and pass it to the `bufferization.dealloc` operation. This commit factors some of the pass's simplification logic out into a canonicalization pattern.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D157255
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/test/Dialect/Bufferization/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 8be72fb98644f1..d5237164bd0e83 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -757,10 +757,11 @@ LogicalResult DeallocOp::verify() {
}
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
- ArrayRef<Value> memrefs,
- ArrayRef<Value> conditions,
+ ValueRange memrefs,
+ ValueRange conditions,
PatternRewriter &rewriter) {
- if (deallocOp.getMemrefs() == memrefs)
+ if (deallocOp.getMemrefs() == memrefs &&
+ deallocOp.getConditions() == conditions)
return failure();
rewriter.updateRootInPlace(deallocOp, [&]() {
@@ -972,6 +973,49 @@ struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
}
};
+/// Canonicalization for `bufferization.dealloc`: an operand defined by
+/// `memref.extract_strided_metadata` is replaced by the extraction's source
+/// memref whenever the source's defining op reports a
+/// `MemoryEffects::Allocate` effect on exactly that value. Such extractions
+/// are typically inserted to obtain the base memref when an operand is not
+/// known to come directly from an allocation; once other rewrites expose the
+/// allocation, the indirection is no longer needed.
+///
+/// Only the dealloc's memref operands are rewritten (conditions are passed
+/// through unchanged); the extraction op itself is not erased here.
+///
+/// Example:
+/// ```mlir
+/// %alloc = memref.alloc() : memref<2xi32>
+/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
+///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
+/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
+/// ```
+/// is rewritten into
+/// ```mlir
+/// %alloc = memref.alloc() : memref<2xi32>
+/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
+/// ```
+/// Note that the operand type changes accordingly (here `memref<i32>` becomes
+/// `memref<2xi32>`).
+struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> rewrittenMemrefs;
+    for (Value operand : deallocOp.getMemrefs())
+      rewrittenMemrefs.push_back(lookThroughExtractOfAlloc(operand));
+
+    // Reports failure (leaving the op untouched) when no operand changed.
+    return updateDeallocIfChanged(deallocOp, rewrittenMemrefs,
+                                  deallocOp.getConditions(), rewriter);
+  }
+
+private:
+  /// Returns the source of `operand`'s defining `extract_strided_metadata` op
+  /// if that source is allocated by its own defining op; otherwise returns
+  /// `operand` unchanged.
+  static Value lookThroughExtractOfAlloc(Value operand) {
+    auto metadataOp =
+        operand.getDefiningOp<memref::ExtractStridedMetadataOp>();
+    if (!metadataOp)
+      return operand;
+
+    Value source = metadataOp.getOperand();
+    auto effectOp = source.getDefiningOp<MemoryEffectOpInterface>();
+    // The Allocate effect must be attached to `source` itself, not merely to
+    // some other result of the same op.
+    const bool sourceIsAllocated =
+        effectOp && effectOp.getEffectOnValue<MemoryEffects::Allocate>(source);
+    return sourceIsAllocated ? source : operand;
+  }
+};
+
} // anonymous namespace
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -979,7 +1023,7 @@ void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<DeallocRemoveDuplicateDeallocMemrefs,
DeallocRemoveDuplicateRetainedMemrefs,
DeallocRemoveDeallocMemrefsContainedInRetained, EraseEmptyDealloc,
- EraseAlwaysFalseDealloc>(context);
+ EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 6a4edf1e6335f1..af222899e5bbd5 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -323,3 +323,20 @@ func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2x
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
// CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
// CHECK-NEXT: return
+
+// -----
+
+// Only the extraction whose source is produced by `memref.alloc` is looked
+// through; the extraction of the function argument %arg3 must be kept.
+// Result names follow the op's result order: base, offset, sizes, strides.
+func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) {
+  %alloc = memref.alloc() : memref<2xi32>
+  %base0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index
+  %base1, %offset1, %size1, %stride1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index
+  bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2)
+  return
+}
+
+// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc() : memref<2xi32>
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] :
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]])
+// CHECK-NEXT: return
More information about the Mlir-commits mailing list