[Mlir-commits] [mlir] 1f07bfb - [mlir][tensor] Implement folding logic for size 0 tensor and memref ops (#90814)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon May 20 13:36:48 PDT 2024
    
    
  
Author: Spenser Bauman
Date: 2024-05-20T16:36:45-04:00
New Revision: 1f07bfb92c2a62731a5ae3ec2d135e3869634c01
URL: https://github.com/llvm/llvm-project/commit/1f07bfb92c2a62731a5ae3ec2d135e3869634c01
DIFF: https://github.com/llvm/llvm-project/commit/1f07bfb92c2a62731a5ae3ec2d135e3869634c01.diff
LOG: [mlir][tensor] Implement folding logic for size 0 tensor and memref ops (#90814)
Implement folding and rewrite logic to eliminate no-op tensor and memref
operations. This handles two specific cases:
1. tensor.insert_slice operations where the size of the inserted slice
is known to be 0.
2. memref.copy operations where either the source or target memrefs are
known to be empty.
Co-authored-by: Spenser Bauman <sabauma at fastmail>
Added: 
    
Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 45f39c80041c9..d70e6d0b79cd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
     return success();
   }
 };
+
+struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  static bool isEmptyMemRef(BaseMemRefType type) {
+    return type.hasRank() &&
+           llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
+  }
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    if (isEmptyMemRef(copyOp.getSource().getType()) ||
+        isEmptyMemRef(copyOp.getTarget().getType())) {
+      rewriter.eraseOp(copyOp);
+      return success();
+    }
+
+    return failure();
+  }
+};
 } // namespace
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
 }
 
 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8a6df82abb312..8545c7b9af8f7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2609,6 +2609,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
     return result;
+  if (llvm::any_of(getMixedSizes(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+    return getDest();
   return OpFoldResult();
 }
 
diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f442a61dc31ed..c4ff6480a4ce5 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @empty_copy
+//  CHECK-NEXT:   return
+func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
+  memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
+  memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
+  return
+}
+
+// -----
+
 func.func @scopeMerge() {
   memref.alloca_scope {
     %cnt = "test.count"() : () -> index
diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index b5a82eb3e9035..914e5e8b8c4b8 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
 
 // -----
 
+// CHECK-LABEL: func @empty_insert_slice
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
+//  CHECK-SAME:   %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   return %[[ARG1]] :  tensor<3x3xi8>
+func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
+  return %0 : tensor<3x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
        
    
    
More information about the Mlir-commits
mailing list