[Mlir-commits] [mlir] f22d381 - [mlir] Canonicalize AllocOp's with only store and dealloc uses
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 23 23:55:08 PDT 2021
Author: Butygin
Date: 2021-04-24T09:51:00+03:00
New Revision: f22d3813850f9e87c5204df6844a93b8c5db7730
URL: https://github.com/llvm/llvm-project/commit/f22d3813850f9e87c5204df6844a93b8c5db7730
DIFF: https://github.com/llvm/llvm-project/commit/f22d3813850f9e87c5204df6844a93b8c5db7730.diff
LOG: [mlir] Canonicalize AllocOp's with only store and dealloc uses
Differential Revision: https://reviews.llvm.org/D100268
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index fe0fd7d0ff36..7b341b1940cf 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -378,7 +378,6 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
let arguments = (ins Arg<AnyMemRef, "", [MemFree]>:$memref);
- let hasCanonicalizer = 1;
let hasFolder = 1;
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7e55f4c2877b..1ac00022e232 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -195,30 +195,46 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
   }
 };
 
-/// Fold alloc operations with no uses. Alloc has side effects on the heap,
-/// but can still be deleted if it has zero uses.
-struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
-  using OpRewritePattern<AllocOp>::OpRewritePattern;
+/// Erase an alloc-like op whose buffer is never read. Every user must be a
+/// dealloc of the buffer or a store *into* it; those users are erased along
+/// with the allocation. A store that writes the buffer itself as the stored
+/// value lets it escape, so that case is left alone.
+template <typename T>
+struct SimplifyDeadAlloc : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(AllocOp alloc,
+  LogicalResult matchAndRewrite(T alloc,
                                 PatternRewriter &rewriter) const override {
-    if (alloc.use_empty()) {
-      rewriter.eraseOp(alloc);
-      return success();
-    }
-    return failure();
+    Value memref = alloc->getResult(0);
+    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
+          // Storing the buffer itself (`memref.store %alloc, %other[...]`)
+          // publishes it to other code, so it must not be erased.
+          if (auto storeOp = dyn_cast<StoreOp>(op))
+            return storeOp.value() == memref;
+          return !isa<DeallocOp>(op);
+        }))
+      return failure();
+
+    // Every remaining user only writes into or frees the buffer, and nothing
+    // can read it back, so erase the users together with the allocation.
+    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
+      rewriter.eraseOp(user);
+
+    rewriter.eraseOp(alloc);
+    return success();
   }
 };
 } // end anonymous namespace.
 
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
+  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
 }
 
 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocaOp>>(context);
+  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -537,30 +543,6 @@ OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
-namespace {
-/// Fold Dealloc operations that are deallocating an AllocOp that is only used
-/// by other Dealloc operations.
-struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
- using OpRewritePattern<DeallocOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DeallocOp dealloc,
- PatternRewriter &rewriter) const override {
- // Check that the memref operand's defining operation is an AllocOp.
- Value memref = dealloc.memref();
- if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
- return failure();
-
- // Check that all of the uses of the AllocOp are other DeallocOps.
- for (auto *user : memref.getUsers())
- if (!isa<DeallocOp>(user))
- return failure();
-
- // Erase the dealloc operation.
- rewriter.eraseOp(dealloc);
- return success();
- }
-};
-} // end anonymous namespace.
static LogicalResult verify(DeallocOp op) {
if (!op.memref().getType().isa<MemRefType>())
@@ -568,11 +550,6 @@ static LogicalResult verify(DeallocOp op) {
return success();
}
-void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<SimplifyDeadDealloc>(context);
-}
-
LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dealloc(memrefcast) -> dealloc
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d766dd9956f8..7f528fb26fe5 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -93,11 +93,8 @@ func @compose_affine_maps_1dto2d_with_symbols() {
// CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0, d1) -> (d1 + (d0 ceildiv 4) * 4 - (d1 floordiv 4) * 4)>
// CHECK-DAG: #[[$MAP8a:.*]] = affine_map<(d0, d1) -> (d1 + (d0 ceildiv 8) * 8 - (d1 floordiv 8) * 8)>
-// CHECK-LABEL: func @compose_affine_maps_2d_tile() {
-func @compose_affine_maps_2d_tile() {
- %0 = memref.alloc() : memref<16x32xf32>
- %1 = memref.alloc() : memref<16x32xf32>
-
+// CHECK-LABEL: func @compose_affine_maps_2d_tile
+func @compose_affine_maps_2d_tile(%0: memref<16x32xf32>, %1: memref<16x32xf32>) {
%c4 = constant 4 : index
%c8 = constant 8 : index
@@ -221,7 +218,7 @@ func @compose_affine_maps_multiple_symbols(%arg0: index, %arg1: index) -> index
// -----
// CHECK-LABEL: func @arg_used_as_dim_and_symbol
-func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2: f32) {
+func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2: f32) -> (memref<100x100xf32, 1>, memref<1xi32>) {
%c9 = constant 9 : index
%1 = memref.alloc() : memref<100x100xf32, 1>
%2 = memref.alloc() : memref<1xi32>
@@ -235,7 +232,7 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index, %arg2
memref.store %arg2, %1[%4, %arg1] : memref<100x100xf32, 1>
}
}
- return
+ return %1, %2 : memref<100x100xf32, 1>, memref<1xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c8ad16ab9b14..be22f323873e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -207,9 +207,8 @@ func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tenso
// CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
// CHECK-SAME: %[[IDX:.*]]: index
-func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
+func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
%size = rank %tensor : tensor<*xf32>
- %mem = memref.alloc(%size) : memref<?xindex>
// CHECK: %[[DTENSOR:.*]] = tensor.generate
%0 = tensor.generate %size {
^bb0(%arg0: index):
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 988f7ba3bfc1..49a35d162b2b 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -410,6 +410,27 @@ func @dead_dealloc_fold_multi_use(%cond : i1) {
return
}
+// CHECK-LABEL: func @write_only_alloc_fold
+func @write_only_alloc_fold(%v: f32) {
+ // CHECK-NEXT: return
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %a = memref.alloc(%c4) : memref<?xf32>
+ memref.store %v, %a[%c0] : memref<?xf32>
+ memref.dealloc %a: memref<?xf32>
+ return
+}
+
+// CHECK-LABEL: func @write_only_alloca_fold
+func @write_only_alloca_fold(%v: f32) {
+ // CHECK-NEXT: return
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %a = memref.alloca(%c4) : memref<?xf32>
+ memref.store %v, %a[%c0] : memref<?xf32>
+ return
+}
+
// CHECK-LABEL: func @dead_block_elim
func @dead_block_elim() {
// CHECK-NOT: ^bb
@@ -426,7 +447,7 @@ func @dead_block_elim() {
}
// CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
-func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>) {
+func @dyn_shape_fold(%L : index, %M : index) -> (memref<4 x ? x 8 x ? x ? x f32>, memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>) {
// CHECK: %c0 = constant 0 : index
%zero = constant 0 : index
// The constants below disappear after they propagate into shapes.
@@ -434,13 +455,13 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
%N = constant 1024 : index
%K = constant 512 : index
- // CHECK-NEXT: memref.alloc(%arg0) : memref<?x1024xf32>
+ // CHECK: memref.alloc(%arg0) : memref<?x1024xf32>
%a = memref.alloc(%L, %N) : memref<? x ? x f32>
- // CHECK-NEXT: memref.alloc(%arg1) : memref<4x1024x8x512x?xf32>
+ // CHECK: memref.alloc(%arg1) : memref<4x1024x8x512x?xf32>
%b = memref.alloc(%N, %K, %M) : memref<4 x ? x 8 x ? x ? x f32>
- // CHECK-NEXT: memref.alloc() : memref<512x1024xi32>
+ // CHECK: memref.alloc() : memref<512x1024xi32>
%c = memref.alloc(%K, %N) : memref<? x ? x i32>
// CHECK: memref.alloc() : memref<9x9xf32>
@@ -460,7 +481,7 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
}
}
- return %c, %d, %e : memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>
+ return %b, %c, %d, %e : memref<4 x ? x 8 x ? x ? x f32>, memref<? x ? x i32>, memref<? x ? x f32>, memref<4 x ? x 8 x ? x ? x f32>
}
#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
More information about the Mlir-commits
mailing list