[Mlir-commits] [mlir] [mlir][tensor] Implement folding logic for size 0 tensor and memref ops (PR #90814)

Spenser Bauman llvmlistbot at llvm.org
Wed May 1 19:37:36 PDT 2024


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/90814

Implement folding and rewrite logic to eliminate no-op tensor and memref operations. This handles two specific cases:

1. tensor.insert_slice operations where the size of the inserted slice is known to be 0.
2. memref.copy operations where either the source or target memrefs are known to be empty.

>From 5ed96f68e63c4f8de69563693ffe28a557301217 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Wed, 1 May 2024 19:30:54 -0400
Subject: [PATCH] Implement folding logic for size 0 tensor and memref ops

Implement folding and rewrite logic to eliminate no-op tensor and memref
operations. This handles two specific cases:

1. tensor.insert_slice operations where the size of the inserted slice
   is known to be 0.
2. memref.copy operations where either the source or target memrefs are
   known to be empty.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 22 +++++++++++++++++++++-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  3 +++
 mlir/test/Dialect/MemRef/canonicalize.mlir | 10 ++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 12 ++++++++++++
 4 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b969d41d934d41..675aeacd8f0e23 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
     return success();
   }
 };
+
+struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  static bool isEmptyMemRef(BaseMemRefType type) {
+    return type.hasRank() &&
+      llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
+  }
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter& rewriter) const override {
+    if (isEmptyMemRef(copyOp.getSource().getType()) ||
+        isEmptyMemRef(copyOp.getTarget().getType())) {
+      rewriter.eraseOp(copyOp);
+      return success();
+    }
+
+    return failure();
+  }
+};
 } // namespace
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
 }
 
 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5f..ef8a078078c864 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2606,6 +2606,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
     return result;
+  if (llvm::any_of(getMixedSizes(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+    return getDest();
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f442a61dc31ed1..c4ff6480a4ce5e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @empty_copy
+//  CHECK-NEXT:   return
+func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
+  memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
+  memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
+  return
+}
+
+// -----
+
 func.func @scopeMerge() {
   memref.alloca_scope {
     %cnt = "test.count"() : () -> index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c93..e8adb7653c3e23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
 
 // -----
 
+// CHECK-LABEL: func @empty_insert_slice
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
+//  CHECK-SAME:   %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   return %[[ARG1]] :  tensor<3x3xi8>
+func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
+  return %0 : tensor<3x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>



More information about the Mlir-commits mailing list