[Mlir-commits] [mlir] 01f4390 - [MLIR] Fold memref.reinterpret_cast(x) -> x when the type is fully static and does not change

Andrey Turetskiy llvmlistbot at llvm.org
Wed Aug 30 20:55:39 PDT 2023


Author: Andrey Turetskiy
Date: 2023-08-30T20:50:18-07:00
New Revision: 01f4390a519f9990f4b1bf602c30dba6914f7ac3

URL: https://github.com/llvm/llvm-project/commit/01f4390a519f9990f4b1bf602c30dba6914f7ac3
DIFF: https://github.com/llvm/llvm-project/commit/01f4390a519f9990f4b1bf602c30dba6914f7ac3.diff

LOG: [MLIR] Fold memref.reinterpret_cast(x) -> x when the type is fully static and does not change.

Differential Revision: https://reviews.llvm.org/D149296

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c5c322e23692b..d08da74df47976 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1926,6 +1926,12 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     return getResult();
   }
 
+  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
+  if (!ShapedType::isDynamicShape(getType().getShape()) &&
+      src.getType() == getType() && getStaticOffsets().front() == 0) {
+    return src;
+  }
+
   return nullptr;
 }
 

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index df66705e83e0e2..c9f874e4cf3051 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -719,6 +719,16 @@ func.func @scopeInline(%arg : memref<index>) {
 
 // -----
 
+// CHECK-LABEL: func @reinterpret_noop
+//  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
+//  CHECK-NEXT: return %[[ARG]]
+func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
+  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1] : memref<2x3x4xf32> to memref<2x3x4xf32>
+  return %0 : memref<2x3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_of_reinterpret
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
 //       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]

