[Mlir-commits] [mlir] f22d381 - [mlir] Canonicalize AllocOp's with only store and dealloc uses

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 23 23:55:08 PDT 2021


Author: Butygin
Date: 2021-04-24T09:51:00+03:00
New Revision: f22d3813850f9e87c5204df6844a93b8c5db7730

URL: https://github.com/llvm/llvm-project/commit/f22d3813850f9e87c5204df6844a93b8c5db7730
DIFF: https://github.com/llvm/llvm-project/commit/f22d3813850f9e87c5204df6844a93b8c5db7730.diff

LOG: [mlir] Canonicalize AllocOp's with only store and dealloc uses

Differential Revision: https://reviews.llvm.org/D100268

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index fe0fd7d0ff36..7b341b1940cf 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -378,7 +378,6 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
 
   let arguments = (ins Arg<AnyMemRef, "", [MemFree]>:$memref);
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7e55f4c2877b..1ac00022e232 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -195,30 +195,40 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
   }
 };
 
-/// Fold alloc operations with no uses. Alloc has side effects on the heap,
-/// but can still be deleted if it has zero uses.
-struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
-  using OpRewritePattern<AllocOp>::OpRewritePattern;
+/// Fold alloc operations with no users or only store and dealloc uses.
+/// Stores count only when they write *into* the allocation; storing the
+/// allocated memref itself as a value lets it escape and blocks the fold.
+template <typename T>
+struct SimplifyDeadAlloc : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(AllocOp alloc,
+  LogicalResult matchAndRewrite(T alloc,
                                 PatternRewriter &rewriter) const override {
-    if (alloc.use_empty()) {
-      rewriter.eraseOp(alloc);
-      return success();
-    }
-    return failure();
+    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
+          if (auto storeOp = dyn_cast<StoreOp>(op))
+            return storeOp.value() == alloc;
+          return !isa<DeallocOp>(op);
+        }))
+      return failure();
+
+    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
+      rewriter.eraseOp(user);
+
+    rewriter.eraseOp(alloc);
+    return success();
   }
 };
 } // end anonymous namespace.
 
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
+  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
 }
 
 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocaOp>>(context);
+  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -537,30 +547,6 @@ OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
-namespace {
-/// Fold Dealloc operations that are deallocating an AllocOp that is only used
-/// by other Dealloc operations.
-struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
-  using OpRewritePattern<DeallocOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DeallocOp dealloc,
-                                PatternRewriter &rewriter) const override {
-    // Check that the memref operand's defining operation is an AllocOp.
-    Value memref = dealloc.memref();
-    if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
-      return failure();
-
-    // Check that all of the uses of the AllocOp are other DeallocOps.
-    for (auto *user : memref.getUsers())
-      if (!isa<DeallocOp>(user))
-        return failure();
-
-    // Erase the dealloc operation.
-    rewriter.eraseOp(dealloc);
-    return success();
-  }
-};
-} // end anonymous namespace.
 
 static LogicalResult verify(DeallocOp op) {
   if (!op.memref().getType().isa<MemRefType>())
@@ -568,11 +554,6 @@ static LogicalResult verify(DeallocOp op) {
   return success();
 }
 
-void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<SimplifyDeadDealloc>(context);
-}
-
 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dealloc(memrefcast) -> dealloc

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d766dd9956f8..7f528fb26fe5 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -93,11 +93,8 @@ func @compose_affine_maps_1dto2d_with_symbols() {
 // CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0, d1) -> (d1 + (d0 ceildiv 4) * 4 - (d1 floordiv 4) * 4)>
 // CHECK-DAG: #[[$MAP8a:.*]] = affine_map<(d0, d1) -> (d1 + (d0 ceildiv 8) * 8 - (d1 floordiv 8) * 8)>
 
-// CHECK-LABEL: func @compose_affine_maps_2d_tile() {
-func @compose_affine_maps_2d_tile() {
-  %0 = memref.alloc() : memref<16x32xf32>
-  %1 = memref.alloc() : memref<16x32xf32>
-
+// CHECK-LABEL: func @compose_affine_maps_2d_tile
+func @compose_affine_maps_2d_tile(%0: memref<16x32xf32>, %1: memref<16x32xf32>) {
   %c4 = constant 4 : index
   %c8 = constant 8 : index
 
@@ -221,7 +218,7 @@ func @compose_affine_maps_multiple_symbols(%arg0: index, %arg1: index) -> index
 // -----
 
 // CHECK-LABEL: func @arg_used_as_dim_and_symbol
-func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2: f32) {
+func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2: f32) -> (memref<100x100xf32, 1>, memref<1xi32>) {
   %c9 = constant 9 : index
   %1 = memref.alloc() : memref<100x100xf32, 1>
   %2 = memref.alloc() : memref<1xi32>
@@ -235,7 +232,7 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2
       memref.store %arg2, %1[%4, %arg1] : memref<100x100xf32, 1>
     }
   }
-  return
+  return %1, %2 : memref<100x100xf32, 1>, memref<1xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c8ad16ab9b14..be22f323873e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -207,9 +207,8 @@ func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tenso
 
 // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
 // CHECK-SAME: %[[IDX:.*]]: index
-func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
+func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
   %size = rank %tensor : tensor<*xf32>
-  %mem = memref.alloc(%size) : memref<?xindex>
   // CHECK: %[[DTENSOR:.*]] = tensor.generate
   %0 = tensor.generate %size {
     ^bb0(%arg0: index):

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 988f7ba3bfc1..49a35d162b2b 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -410,6 +410,27 @@ func @dead_dealloc_fold_multi_use(%cond : i1) {
   return
 }
 
+// CHECK-LABEL: func @write_only_alloc_fold
+func @write_only_alloc_fold(%v: f32) {
+  // CHECK-NEXT: return
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %a = memref.alloc(%c4) : memref<?xf32>
+  memref.store %v, %a[%c0] : memref<?xf32>
+  memref.dealloc %a: memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @write_only_alloca_fold
+func @write_only_alloca_fold(%v: f32) {
+  // CHECK-NEXT: return
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %a = memref.alloca(%c4) : memref<?xf32>
+  memref.store %v, %a[%c0] : memref<?xf32>
+  return
+}
+
 // CHECK-LABEL: func @dead_block_elim
 func @dead_block_elim() {
   // CHECK-NOT: ^bb
@@ -426,7 +447,7 @@ func @dead_block_elim() {
 }
 
 // CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
-func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>) {
+func @dyn_shape_fold(%L : index, %M : index) -> (memref<4 x ? x 8 x ? x ? x f32>, memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>) {
   // CHECK: %c0 = constant 0 : index
   %zero = constant 0 : index
   // The constants below disappear after they propagate into shapes.
@@ -434,13 +455,13 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
   %N = constant 1024 : index
   %K = constant 512 : index
 
-  // CHECK-NEXT: memref.alloc(%arg0) : memref<?x1024xf32>
+  // CHECK: memref.alloc(%arg0) : memref<?x1024xf32>
   %a = memref.alloc(%L, %N) : memref<? x ? x f32>
 
-  // CHECK-NEXT: memref.alloc(%arg1) : memref<4x1024x8x512x?xf32>
+  // CHECK: memref.alloc(%arg1) : memref<4x1024x8x512x?xf32>
   %b = memref.alloc(%N, %K, %M) : memref<4 x ? x 8 x ? x ? x f32>
 
-  // CHECK-NEXT: memref.alloc() : memref<512x1024xi32>
+  // CHECK: memref.alloc() : memref<512x1024xi32>
   %c = memref.alloc(%K, %N) : memref<? x ? x i32>
 
   // CHECK: memref.alloc() : memref<9x9xf32>
@@ -460,7 +481,7 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
     }
   }
 
-  return %c, %d, %e : memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>
+  return %b, %c, %d, %e : memref<4 x ? x 8 x ? x ? x f32>, memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>
 }
 
 #map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>


        


More information about the Mlir-commits mailing list