[Mlir-commits] [mlir] fff1830 - [mlir][bufferization] Run the simple dealloc canonicalization patterns as part of BufferDeallocationSimplification

Martin Erhart llvmlistbot at llvm.org
Mon Aug 28 01:04:22 PDT 2023


Author: Martin Erhart
Date: 2023-08-28T08:04:03Z
New Revision: fff183050adbd2c1f2e7fe78055a75aadb1694bb

URL: https://github.com/llvm/llvm-project/commit/fff183050adbd2c1f2e7fe78055a75aadb1694bb
DIFF: https://github.com/llvm/llvm-project/commit/fff183050adbd2c1f2e7fe78055a75aadb1694bb.diff

LOG: [mlir][bufferization] Run the simple dealloc canonicalization patterns as part of BufferDeallocationSimplification

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158744

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
    mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 08d7126eca3153..450dfb37ddb2e1 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -58,6 +58,12 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
                                        ToMemrefOp toMemref);
 
+/// Add the canonicalization patterns for bufferization.dealloc to the given
+/// pattern set to make them available to other passes (such as
+/// BufferDeallocationSimplification).
+void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context);
+
 } // namespace bufferization
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 83427eb7122afd..9a2a6d0f5c6d98 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -1018,10 +1018,15 @@ struct RemoveAllocDeallocPairWhenNoOtherUsers
 
 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeallocRemoveDuplicateDeallocMemrefs,
-              DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
-              EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
-              RemoveAllocDeallocPairWhenNoOtherUsers>(context);
+  populateDeallocOpCanonicalizationPatterns(results, context);
+}
+
+void bufferization::populateDeallocOpCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
+               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
+               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
+               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 1090cd668ca78c..5a6de372e2310d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -432,6 +432,7 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 aliasAnalysis);
+    populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index 473ce8734f0cd9..98eb038df30a3b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -15,7 +15,6 @@ func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg
 // CHECK-LABEL: func @dealloc_deallocated_in_retained
 //  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
 //  CHECK-NEXT: arith.constant false
-//  CHECK-NEXT: bufferization.dealloc
 //  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 //  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
 //  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
@@ -23,7 +22,6 @@ func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg
 // COM: retained memrefs since the list of memrefs to be deallocated becomes empty
 // COM: due to the pattern under test (and thus there is no memref the retain values
 // COM: could alias to)
-//  CHECK-NEXT: bufferization.dealloc
 // CHECK-NOT: if
 //  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
 //  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
@@ -50,7 +48,6 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
 //  CHECK-NEXT: arith.constant false
 //  CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
 //  CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
-//  CHECK-NEXT: bufferization.dealloc
 //  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 //  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
 //  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
@@ -58,7 +55,6 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
 // COM: retained memrefs since the list of memrefs to be deallocated becomes empty
 // COM: due to the pattern under test (and thus there is no memref the retain values
 // COM: could alias to)
-//  CHECK-NEXT: bufferization.dealloc
 // CHECK-NOT: if
 //  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
 //  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
@@ -66,11 +62,11 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
 
 // -----
 
-func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
+func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1, memref<2xi32>) {
   %alloc = memref.alloc() : memref<2xi32>
   %alloc0 = memref.alloc() : memref<2xi32>
   %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
-  return %0#0, %0#1 : i1, i1
+  return %0#0, %0#1, %alloc : i1, i1, memref<2xi32>
 }
 
 // CHECK-LABEL: func @remove_retained_memrefs_guarateed_to_not_alias
@@ -79,7 +75,7 @@ func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memr
 //  CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
 //  CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
 //  CHECK-NOT: retain
-//  CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+//  CHECK-NEXT: return [[FALSE]], [[FALSE]], [[ALLOC]] :
 
 // -----
 
@@ -104,7 +100,6 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
 //  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
 //  CHECK-NEXT:   [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
 //  CHECK-NEXT:   [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
-//  CHECK-NEXT:   bufferization.dealloc
 //  CHECK-NEXT:   return [[V2]]#0, [[V3]] :
 
 // -----


        


More information about the Mlir-commits mailing list