[Mlir-commits] [mlir] [mlir][tensor] Implement folding logic for size 0 tensor and memref ops (PR #90814)
Spenser Bauman
llvmlistbot at llvm.org
Wed May 1 19:37:36 PDT 2024
https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/90814
Implement folding and rewrite logic to eliminate no-op tensor and memref operations. This handles two specific cases, sketched in the example below the list:
1. tensor.insert_slice operations where the size of the inserted slice is known to be 0.
2. memref.copy operations where either the source or the target memref is known to be empty.
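As a rough illustration (the function and value names below are made up for this example and are not taken from the patch's test files), the canonicalizer should rewrite IR like this:

  // Before canonicalization: the inserted slice has a static size of 0 along
  // its first dimension, so the insert_slice contributes nothing and folds to
  // its destination operand.
  func.func @fold_example(%src : tensor<0x2xi8>, %dst : tensor<3x3xi8>) -> tensor<3x3xi8> {
    %0 = tensor.insert_slice %src into %dst[0, 0] [0, 2] [1, 1]
        : tensor<0x2xi8> into tensor<3x3xi8>
    return %0 : tensor<3x3xi8>
  }

  // After -canonicalize: the insert_slice folds away and the destination is
  // returned directly.
  func.func @fold_example(%src : tensor<0x2xi8>, %dst : tensor<3x3xi8>) -> tensor<3x3xi8> {
    return %dst : tensor<3x3xi8>
  }

  // Likewise, a memref.copy where either operand has a zero-sized dimension
  // copies no data, so the new pattern simply erases it:
  //   memref.copy %a, %b : memref<0x10xf32> to memref<?x10xf32>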
From 5ed96f68e63c4f8de69563693ffe28a557301217 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Wed, 1 May 2024 19:30:54 -0400
Subject: [PATCH] Implement folding logic for size 0 tensor and memref ops
Implement folding and rewrite logic to eliminate no-op tensor and memref
operations. This handles two specific cases:
1. tensor.insert_slice operations where the size of the inserted slice
is known to be 0.
2. memref.copy operations where either the source or the target memref is
   known to be empty.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 22 +++++++++++++++++++++-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 +++
mlir/test/Dialect/MemRef/canonicalize.mlir | 10 ++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 12 ++++++++++++
4 files changed, 46 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b969d41d934d41..675aeacd8f0e23 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
return success();
}
};
+
+struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ static bool isEmptyMemRef(BaseMemRefType type) {
+ return type.hasRank() &&
+ llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
+ }
+
+ LogicalResult matchAndRewrite(CopyOp copyOp,
+ PatternRewriter &rewriter) const override {
+ if (isEmptyMemRef(copyOp.getSource().getType()) ||
+ isEmptyMemRef(copyOp.getTarget().getType())) {
+ rewriter.eraseOp(copyOp);
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+ results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5f..ef8a078078c864 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2606,6 +2606,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
+ if (llvm::any_of(getMixedSizes(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+ return getDest();
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f442a61dc31ed1..c4ff6480a4ce5e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
// -----
+// CHECK-LABEL: func @empty_copy
+// CHECK-NEXT: return
+func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
+ memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
+ memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
+ return
+}
+
+// -----
+
func.func @scopeMerge() {
memref.alloca_scope {
%cnt = "test.count"() : () -> index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c93..e8adb7653c3e23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
// -----
+// CHECK-LABEL: func @empty_insert_slice
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
+// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
+// CHECK-NOT: tensor.insert_slice
+// CHECK: return %[[ARG1]] : tensor<3x3xi8>
+func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
+ %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
+ return %0 : tensor<3x3xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @rank_reducing_tensor_of_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>