[Mlir-commits] [mlir] 96acdfa - [mlir][memref] Fold copy of cast
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 14 04:56:19 PST 2022
Author: Matthias Springer
Date: 2022-01-14T21:51:12+09:00
New Revision: 96acdfa0de3e574453d3f62aa7a71397ac0120bb
URL: https://github.com/llvm/llvm-project/commit/96acdfa0de3e574453d3f62aa7a71397ac0120bb
DIFF: https://github.com/llvm/llvm-project/commit/96acdfa0de3e574453d3f62aa7a71397ac0120bb.diff
LOG: [mlir][memref] Fold copy of cast
If the source or target of a memref.copy is a memref.cast that does not change the shape or element type, the cast can be skipped.
Differential Revision: https://reviews.llvm.org/D117215
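
To illustrate the fold, here is a minimal before/after sketch in MLIR form; the function names, layout map, and memref types below are made up for this example and are not taken from the commit:

#layout = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Before canonicalization: the copy reads its source through a cast that
// only changes the layout, not the shape or element type.
func @before(%src: memref<?xf32>, %dst: memref<?xf32, #layout>) {
  %0 = memref.cast %src : memref<?xf32> to memref<?xf32, #layout>
  memref.copy %0, %dst : memref<?xf32, #layout> to memref<?xf32, #layout>
  return
}

// After canonicalization: the copy uses %src directly; the now-dead cast is
// erased by the canonicalizer.
func @after(%src: memref<?xf32>, %dst: memref<?xf32, #layout>) {
  memref.copy %src, %dst : memref<?xf32> to memref<?xf32, #layout>
  return
}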
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ebce7276060f7..49f716dbf9b23 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -406,6 +406,7 @@ def CopyOp : MemRef_Op<"copy",
$source `,` $target attr-dict `:` type($source) `to` type($target)
}];
+  let hasCanonicalizer = 1;
let verifier = ?;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 42719eedffa73..11e8de603062e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -438,6 +438,61 @@ OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
+//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// If the source/target of a CopyOp is a CastOp that does not modify the shape
+/// and element type, the cast can be skipped. Such CastOps only cast the layout
+/// of the type.
+struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Check source.
+    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = copyOp.source().getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    // Check target.
+    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = copyOp.target().getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<FoldCopyOfCast>(context);
+}
+
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index fd32430f4e332..6db938f72323a 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -31,7 +31,7 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
- // CHECK: memref.copy %[[A_memref]], %[[casted]]
+ // CHECK: memref.copy %[[A_memref]], %[[alloc]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
@@ -95,4 +95,4 @@ func @rank_reducing(
scf.yield %10 : tensor<?x1x6x8xf32>
}
return %5: tensor<?x1x6x8xf32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index 967b231d73134..e8f484c263261 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -159,7 +159,7 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
// CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
// CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
- // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
+ // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]]
// CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
%0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
// CHECK-TENSOR: return %[[casted_tensor]]
@@ -177,7 +177,7 @@ func @simple_scf_for(
// CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
// CHECK-SCF: %[[alloc:.*]] = memref.alloc
// CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]]
- // CHECK-SCF: memref.copy %[[t1_memref]], %[[casted]]
+ // CHECK-SCF: memref.copy %[[t1_memref]], %[[alloc]]
// CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
%0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor<?xf32> {
// CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]]
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 6ddb49de9932e..092f527a15af3 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -510,3 +510,18 @@ func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
// CHECK-LABEL: func @atomicrmw_cast_fold
// CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32
+
+// -----
+
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
+  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, #map>
+  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
+  memref.copy %casted1, %casted2 : memref<?xf32, #map> to memref<?xf32, #map>
+  return
+}
+
+// CHECK-LABEL: func @copy_of_cast(
+// CHECK-SAME: %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
+// CHECK: %[[casted2:.*]] = memref.cast %[[m2]]
+// CHECK: memref.copy %[[m1]], %[[casted2]]
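
Reading the CHECK lines together with the input above, the canonicalized IR is expected to look roughly as follows (a reconstruction for illustration, not part of the commit). Only the cast of the ranked %m1 folds away; the cast of the unranked %m2 stays, because the pattern only fires when the cast source is a ranked MemRefType with matching shape and element type:

#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
  memref.copy %m1, %casted2 : memref<?xf32> to memref<?xf32, #map>
  return
}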