[Mlir-commits] [mlir] 1f07bfb - [mlir][tensor] Implement folding logic for size 0 tensor and memref ops (#90814)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 20 13:36:48 PDT 2024
Author: Spenser Bauman
Date: 2024-05-20T16:36:45-04:00
New Revision: 1f07bfb92c2a62731a5ae3ec2d135e3869634c01
URL: https://github.com/llvm/llvm-project/commit/1f07bfb92c2a62731a5ae3ec2d135e3869634c01
DIFF: https://github.com/llvm/llvm-project/commit/1f07bfb92c2a62731a5ae3ec2d135e3869634c01.diff
LOG: [mlir][tensor] Implement folding logic for size 0 tensor and memref ops (#90814)
Implement folding and rewrite logic to eliminate no-op tensor and memref
operations. This handles two specific cases:
1. tensor.insert_slice operations where the size of the inserted slice
is known to be 0.
2. memref.copy operations where either the source or target memrefs are
known to be empty.
Co-authored-by: Spenser Bauman <sabauma at fastmail>
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 45f39c80041c9..d70e6d0b79cd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
return success();
}
};
+
+struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ static bool isEmptyMemRef(BaseMemRefType type) {
+ return type.hasRank() &&
+ llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
+ }
+
+ LogicalResult matchAndRewrite(CopyOp copyOp,
+ PatternRewriter &rewriter) const override {
+ if (isEmptyMemRef(copyOp.getSource().getType()) ||
+ isEmptyMemRef(copyOp.getTarget().getType())) {
+ rewriter.eraseOp(copyOp);
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+ results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8a6df82abb312..8545c7b9af8f7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2609,6 +2609,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
+ if (llvm::any_of(getMixedSizes(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+ return getDest();
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f442a61dc31ed..c4ff6480a4ce5 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
// -----
+// CHECK-LABEL: func @empty_copy
+// CHECK-NEXT: return
+func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
+ memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
+ memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
+ return
+}
+
+// -----
+
func.func @scopeMerge() {
memref.alloca_scope {
%cnt = "test.count"() : () -> index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index b5a82eb3e9035..914e5e8b8c4b8 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
// -----
+// CHECK-LABEL: func @empty_insert_slice
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
+// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
+// CHECK-NOT: tensor.extract_slice
+// CHECK: return %[[ARG1]] : tensor<3x3xi8>
+func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
+ %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
+ return %0 : tensor<3x3xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @rank_reducing_tensor_of_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
More information about the Mlir-commits
mailing list