[Mlir-commits] [mlir] 0c782c2 - [mlir] Add folding of memref_cast inside another memref_cast

Alex Zinenko llvmlistbot at llvm.org
Fri Nov 6 01:42:48 PST 2020


Author: Alex Zinenko
Date: 2020-11-06T10:42:40+01:00
New Revision: 0c782c214b56102161c03ec5657455b7c73a5722

URL: https://github.com/llvm/llvm-project/commit/0c782c214b56102161c03ec5657455b7c73a5722
DIFF: https://github.com/llvm/llvm-project/commit/0c782c214b56102161c03ec5657455b7c73a5722.diff

LOG: [mlir] Add folding of memref_cast inside another memref_cast

A generic folding facility (foldMemRefCast) already folds the operand of a
memref_cast into those users of the memref_cast that support it. However, it
was not applied to memref_cast itself. Apply it there to enable elimination of
memref_cast chains such as

  %1 = memref_cast %0 : A to B
  %2 = memref_cast %1 : B to A

The chain is eliminated by combining this folding with the existing "A to A"
cast elimination.
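
The interplay of the two folds can be illustrated with a small standalone
model. This is only a sketch and does not use the MLIR API: Shape, Value,
foldCastOnce and foldToFixpoint are hypothetical names invented for the
illustration, a dynamic dimension is encoded as -1, and the repeated
application that the canonicalizer performs is modeled by a plain loop.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Toy stand-in for a memref shape; -1 marks a dynamic ("?") dimension.
using Shape = std::vector<int64_t>;

// Toy SSA value: a function argument (source == nullptr) or the result of a
// cast whose operand is `source`. Hypothetical type, not part of MLIR.
struct Value {
  Shape type;
  Value *source;
  explicit Value(Shape t, Value *s = nullptr)
      : type(std::move(t)), source(s) {}
};

// One folding step on a cast. Returns:
//   - the operand, if the cast is an "A to A" cast and can be dropped;
//   - the cast itself, if its operand was another cast and was replaced
//     in place by that inner cast's operand;
//   - nullptr, if nothing could be folded.
Value *foldCastOnce(Value *cast) {
  assert(cast != nullptr);
  if (!cast->source)
    return nullptr; // Not a cast at all.
  if (cast->source->type == cast->type)
    return cast->source; // "A to A" elimination.
  if (cast->source->source) {
    cast->source = cast->source->source; // Look through the inner cast.
    return cast;
  }
  return nullptr;
}

// Applies foldCastOnce until nothing changes; the loop terminates because
// every successful step shortens the cast chain.
Value *foldToFixpoint(Value *v) {
  while (Value *next = foldCastOnce(v))
    v = next;
  return v;
}
```

In this model, folding the second cast of a 42x42 -> ?x42 -> 42x42 chain first
rewires it to take the original argument and then drops it as an "A to A"
cast, while a 42x42 -> ?x42 -> ?x? chain collapses to a single 42x42 -> ?x?
cast.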

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D90910

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 9b5875e70793..d333ddc8e34c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2386,7 +2386,9 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
 }
 
 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
-  return impl::foldCastOp(*this);
+  if (Value folded = impl::foldCastOp(*this))
+    return folded;
+  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 7b8c45cb409b..08f3ac702596 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -334,6 +334,29 @@ func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> (f32, f32) {
   return %1, %2 : f32, f32
 }
 
+// CHECK-LABEL: @fold_memref_cast_in_memref_cast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_in_memref_cast(%0: memref<42x42xf64>) {
+  // CHECK: %[[folded:.*]] = memref_cast %[[ARG0]] : memref<42x42xf64> to memref<?x?xf64>
+  %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+  // CHECK-NOT: memref_cast
+  %5 = memref_cast %4 : memref<?x42xf64> to memref<?x?xf64>
+  // CHECK: "test.user"(%[[folded]])
+  "test.user"(%5) : (memref<?x?xf64>) -> ()
+  return
+}
+
+// CHECK-LABEL: @fold_memref_cast_chain
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_chain(%0: memref<42x42xf64>) {
+  // CHECK-NOT: memref_cast
+  %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+  %5 = memref_cast %4 : memref<?x42xf64> to memref<42x42xf64>
+  // CHECK: "test.user"(%[[ARG0]])
+  "test.user"(%5) : (memref<42x42xf64>) -> ()
+  return
+}
+
 // CHECK-LABEL: func @alloc_const_fold
 func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: %0 = alloc() : memref<4xf32>


        

