[Mlir-commits] [mlir] 96acdfa - [mlir][memref] Fold copy of cast

Matthias Springer llvmlistbot at llvm.org
Fri Jan 14 04:56:19 PST 2022


Author: Matthias Springer
Date: 2022-01-14T21:51:12+09:00
New Revision: 96acdfa0de3e574453d3f62aa7a71397ac0120bb

URL: https://github.com/llvm/llvm-project/commit/96acdfa0de3e574453d3f62aa7a71397ac0120bb
DIFF: https://github.com/llvm/llvm-project/commit/96acdfa0de3e574453d3f62aa7a71397ac0120bb.diff

LOG: [mlir][memref] Fold copy of cast

If the source/dest is a cast that does not change shape/element type, the cast can be skipped.

Differential Revision: https://reviews.llvm.org/D117215

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ebce7276060f7..49f716dbf9b23 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -406,6 +406,7 @@ def CopyOp : MemRef_Op<"copy",
     $source `,` $target attr-dict `:` type($source) `to` type($target)
   }];
 
+  let hasCanonicalizer = 1;
   let verifier = ?;
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 42719eedffa73..11e8de603062e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -438,6 +438,61 @@ OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// If the source/target of a CopyOp is produced by a CastOp whose operand and
+/// result are both ranked memrefs with identical shape and element type (i.e.
+/// only the layout differs), the copy can use the cast's operand directly.
+struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Source: compare the cast's operand type against its *result* type.
+    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = castOp.getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    // Target: same check; unranked operands fail the dyn_cast and are kept.
+    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = castOp.getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<FoldCopyOfCast>(context); // Only pattern added for CopyOp here.
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index fd32430f4e332..6db938f72323a 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -31,7 +31,7 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
   // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
   // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
-  // CHECK: memref.copy %[[A_memref]], %[[casted]]
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
   // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
   %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
 
@@ -95,4 +95,4 @@ func @rank_reducing(
     scf.yield %10 : tensor<?x1x6x8xf32>
   }
   return %5: tensor<?x1x6x8xf32>
-}
\ No newline at end of file
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index 967b231d73134..e8f484c263261 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -159,7 +159,7 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
   // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
   // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
-  // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]]
   // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
   %0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
   // CHECK-TENSOR: return %[[casted_tensor]]
@@ -177,7 +177,7 @@ func @simple_scf_for(
   // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
   // CHECK-SCF: %[[alloc:.*]] = memref.alloc
   // CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK-SCF: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-SCF: memref.copy %[[t1_memref]], %[[alloc]]
   // CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
   %0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor<?xf32> {
     // CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]]

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 6ddb49de9932e..092f527a15af3 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -510,3 +510,18 @@ func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
 
 // CHECK-LABEL: func @atomicrmw_cast_fold
 // CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32
+
+// -----
+
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
+  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, #map>
+  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
+  memref.copy %casted1, %casted2 : memref<?xf32, #map> to memref<?xf32, #map>
+  return
+}
+
+// CHECK-LABEL: func @copy_of_cast(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
+//       CHECK:   %[[casted2:.*]] = memref.cast %[[m2]]
+//       CHECK:   memref.copy %[[m1]], %[[casted2]]


        


More information about the Mlir-commits mailing list