[Mlir-commits] [mlir] 0c782c2 - [mlir] Add folding of memref_cast inside another memref_cast
Alex Zinenko
llvmlistbot at llvm.org
Fri Nov 6 01:42:48 PST 2020
Author: Alex Zinenko
Date: 2020-11-06T10:42:40+01:00
New Revision: 0c782c214b56102161c03ec5657455b7c73a5722
URL: https://github.com/llvm/llvm-project/commit/0c782c214b56102161c03ec5657455b7c73a5722
DIFF: https://github.com/llvm/llvm-project/commit/0c782c214b56102161c03ec5657455b7c73a5722.diff
LOG: [mlir] Add folding of memref_cast inside another memref_cast
There exists a generic folding facility that folds a memref_cast into those of
its users that support it, by making them use the cast's operand instead of the
cast's result. However, memref_cast itself did not use this facility. Use it in
MemRefCastOp::fold to enable the elimination of memref_cast chains such as

    %1 = memref_cast %0 : A to B
    %2 = memref_cast %1 : B to A

Folding the first cast into the second turns %2 into a cast from A to A, which
the existing "A to A" cast elimination then removes.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D90910
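The interaction described in the log message can be illustrated with a small, self-contained C++ model. This is a hypothetical sketch, not MLIR code. `Value`, `CastOp`, `foldStep` and `foldToFixpoint` are invented names, types are compared as plain strings, and none of the other checks the real folders perform are modeled. `foldStep` plays the roles of `impl::foldCastOp` and `foldMemRefCast` in the patched `MemRefCastOp::fold`. `foldToFixpoint` approximates canonicalization applying `fold` again after an in-place change until nothing more folds.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical toy IR for illustration only; this is NOT the MLIR C++ API.
struct CastOp;

struct Value {
  std::string type;                // e.g. "memref<42x42xf64>"
  CastOp *definingCast = nullptr;  // set when the value is a cast's result
};

struct CastOp {
  CastOp(Value *src, std::string targetType)
      : operand(src), result{std::move(targetType), this} {
    assert(src != nullptr && "a cast needs an operand");
  }
  CastOp(const CastOp &) = delete;  // result.definingCast points back at *this
  CastOp &operator=(const CastOp &) = delete;

  Value *operand;  // value being cast
  Value result;    // result.type is the cast's target type
};

// One fold step, modeled on the patched MemRefCastOp::fold:
//  1. "A to A" elimination: if operand and result types match, forward the
//     operand (the role of impl::foldCastOp).
//  2. Otherwise, if the operand is itself produced by a cast, bypass that
//     inner cast by rewiring `op` to the inner cast's operand, and report an
//     in-place change by returning op's own result (the role of foldMemRefCast).
//  3. Otherwise nothing folds: return nullptr.
Value *foldStep(CastOp &op) {
  if (op.operand->type == op.result.type)
    return op.operand;
  if (CastOp *inner = op.operand->definingCast) {
    op.operand = inner->operand;
    return &op.result;
  }
  return nullptr;
}

// Repeats foldStep until it stops making progress. Each in-place step moves
// the operand one cast further up an acyclic chain, so the loop terminates.
// Returns the value that users of `op` should refer to afterwards.
Value *foldToFixpoint(CastOp &op) {
  while (Value *folded = foldStep(op)) {
    if (folded != &op.result)
      return folded;  // the cast was eliminated; forward this value instead
  }
  return &op.result;  // the cast survives, possibly with a new operand
}
```

In this model, the `@fold_memref_cast_chain` case (42x42 to ?x42 to 42x42) first rewires the outer cast to `%0`. That turns it into a 42x42 to 42x42 cast, so the second step forwards `%0` directly. In `@fold_memref_cast_in_memref_cast` (42x42 to ?x42 to ?x?), the outer cast survives as a single 42x42 to ?x? cast of `%0`. The now-unused inner cast is not erased here, whereas the canonicalizer would remove it as dead code.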
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 9b5875e70793..d333ddc8e34c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2386,7 +2386,9 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
}
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
- return impl::foldCastOp(*this);
+ if (Value folded = impl::foldCastOp(*this))
+ return folded;
+ return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 7b8c45cb409b..08f3ac702596 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -334,6 +334,29 @@ func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> (f32, f32) {
return %1, %2 : f32, f32
}
+// CHECK-LABEL: @fold_memref_cast_in_memref_cast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_in_memref_cast(%0: memref<42x42xf64>) {
+ // CHECK: %[[folded:.*]] = memref_cast %[[ARG0]] : memref<42x42xf64> to memref<?x?xf64>
+ %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+ // CHECK-NOT: memref_cast
+ %5 = memref_cast %4 : memref<?x42xf64> to memref<?x?xf64>
+ // CHECK: "test.user"(%[[folded]])
+ "test.user"(%5) : (memref<?x?xf64>) -> ()
+ return
+}
+
+// CHECK-LABEL: @fold_memref_cast_chain
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_chain(%0: memref<42x42xf64>) {
+ // CHECK-NOT: memref_cast
+ %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+ %5 = memref_cast %4 : memref<?x42xf64> to memref<42x42xf64>
+ // CHECK: "test.user"(%[[ARG0]])
+ "test.user"(%5) : (memref<42x42xf64>) -> ()
+ return
+}
+
// CHECK-LABEL: func @alloc_const_fold
func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: %0 = alloc() : memref<4xf32>