[Mlir-commits] [mlir] fff1830 - [mlir][bufferization] Run the simple dealloc canonicalization patterns as part of BufferDeallocationSimplification
Martin Erhart
llvmlistbot at llvm.org
Mon Aug 28 01:04:22 PDT 2023
Author: Martin Erhart
Date: 2023-08-28T08:04:03Z
New Revision: fff183050adbd2c1f2e7fe78055a75aadb1694bb
URL: https://github.com/llvm/llvm-project/commit/fff183050adbd2c1f2e7fe78055a75aadb1694bb
DIFF: https://github.com/llvm/llvm-project/commit/fff183050adbd2c1f2e7fe78055a75aadb1694bb.diff
LOG: [mlir][bufferization] Run the simple dealloc canonicalization patterns as part of BufferDeallocationSimplification
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158744
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 08d7126eca3153..450dfb37ddb2e1 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -58,6 +58,12 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
ToMemrefOp toMemref);
+/// Add the canonicalization patterns for bufferization.dealloc to the given
+/// pattern set to make them available to other passes (such as
+/// BufferDeallocationSimplification).
+void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context);
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 83427eb7122afd..9a2a6d0f5c6d98 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -1018,10 +1018,15 @@ struct RemoveAllocDeallocPairWhenNoOtherUsers
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DeallocRemoveDuplicateDeallocMemrefs,
- DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
- EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
- RemoveAllocDeallocPairWhenNoOtherUsers>(context);
+ populateDeallocOpCanonicalizationPatterns(results, context);
+}
+
+void bufferization::populateDeallocOpCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
+ DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
+ EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
+ RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 1090cd668ca78c..5a6de372e2310d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -432,6 +432,7 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
aliasAnalysis);
+ populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index 473ce8734f0cd9..98eb038df30a3b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -15,7 +15,6 @@ func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg
// CHECK-LABEL: func @dealloc_deallocated_in_retained
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
// CHECK-NEXT: arith.constant false
-// CHECK-NEXT: bufferization.dealloc
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
@@ -23,7 +22,6 @@ func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg
// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
// COM: due to the pattern under test (and thus there is no memref the retain values
// COM: could alias to)
-// CHECK-NEXT: bufferization.dealloc
// CHECK-NOT: if
// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
@@ -50,7 +48,6 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
// CHECK-NEXT: arith.constant false
// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
-// CHECK-NEXT: bufferization.dealloc
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
@@ -58,7 +55,6 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
// COM: due to the pattern under test (and thus there is no memref the retain values
// COM: could alias to)
-// CHECK-NEXT: bufferization.dealloc
// CHECK-NOT: if
// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
@@ -66,11 +62,11 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
// -----
-func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
+func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1, memref<2xi32>) {
%alloc = memref.alloc() : memref<2xi32>
%alloc0 = memref.alloc() : memref<2xi32>
%0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
- return %0#0, %0#1 : i1, i1
+ return %0#0, %0#1, %alloc : i1, i1, memref<2xi32>
}
// CHECK-LABEL: func @remove_retained_memrefs_guarateed_to_not_alias
@@ -79,7 +75,7 @@ func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memr
// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
// CHECK-NOT: retain
-// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+// CHECK-NEXT: return [[FALSE]], [[FALSE]], [[ALLOC]] :
// -----
@@ -104,7 +100,6 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
// CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
-// CHECK-NEXT: bufferization.dealloc
// CHECK-NEXT: return [[V2]]#0, [[V3]] :
// -----
More information about the Mlir-commits
mailing list