[Mlir-commits] [mlir] 01f4390 - [MLIR] Fold memref.reinterpret_cast(x) -> x when the type is fully static and
Andrey Turetskiy
llvmlistbot at llvm.org
Wed Aug 30 20:55:39 PDT 2023
Author: Andrey Turetskiy
Date: 2023-08-30T20:50:18-07:00
New Revision: 01f4390a519f9990f4b1bf602c30dba6914f7ac3
URL: https://github.com/llvm/llvm-project/commit/01f4390a519f9990f4b1bf602c30dba6914f7ac3
DIFF: https://github.com/llvm/llvm-project/commit/01f4390a519f9990f4b1bf602c30dba6914f7ac3.diff
LOG: [MLIR] Fold memref.reinterpret_cast(x) -> x when the type is fully static and
does not change.
Differential Revision: https://reviews.llvm.org/D149296
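
For readers without an MLIR build at hand, the fold condition in this commit can be approximated with a small standalone model. This is only a sketch under stated assumptions: MemRefTypeModel, kDynamic, contiguousStrides, hasDynamicDim, and isNoOpReinterpretCast are hypothetical names made up for illustration and are not part of the MLIR API. The check mirrors the three conjuncts in the patch (static result shape, identical source and result types, zero static offset), and the stride helper shows why [12, 4, 1] is the identity layout for a 2x3x4 shape, as used in the new test.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sentinel for "not known at compile time" (an assumption of
// this sketch; it only needs to differ from every valid extent).
constexpr int64_t kDynamic = INT64_MIN;

// Hypothetical stand-in for a memref type: shape, strides, and offset.
struct MemRefTypeModel {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  int64_t offset = 0;

  bool operator==(const MemRefTypeModel &other) const {
    return shape == other.shape && strides == other.strides &&
           offset == other.offset;
  }
};

// Row-major strides for a static shape: the innermost stride is 1 and each
// outer stride is the product of all inner extents.
std::vector<int64_t> contiguousStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// True if any extent is the dynamic sentinel.
bool hasDynamicDim(const std::vector<int64_t> &shape) {
  for (int64_t dim : shape)
    if (dim == kDynamic)
      return true;
  return false;
}

// Models the condition added by the patch: the cast is a no-op when the
// result shape is fully static, the source and result types are identical,
// and the static offset is zero.
bool isNoOpReinterpretCast(const MemRefTypeModel &src,
                           const MemRefTypeModel &result,
                           int64_t staticOffset) {
  return !hasDynamicDim(result.shape) && src == result && staticOffset == 0;
}
```

In this model a cast whose result type has a dynamic extent is never treated as a no-op, even when the two types compare equal, because the runtime sizes passed to the op may differ from those of the source.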
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c5c322e23692b..d08da74df47976 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1926,6 +1926,12 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     return getResult();
   }
+  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
+  if (!ShapedType::isDynamicShape(getType().getShape()) &&
+      src.getType() == getType() && getStaticOffsets().front() == 0) {
+    return src;
+  }
+
   return nullptr;
 }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index df66705e83e0e2..c9f874e4cf3051 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -719,6 +719,16 @@ func.func @scopeInline(%arg : memref<index>) {
// -----
+// CHECK-LABEL: func @reinterpret_noop
+// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
+// CHECK-NEXT: return %[[ARG]]
+func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
+  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1] : memref<2x3x4xf32> to memref<2x3x4xf32>
+  return %0 : memref<2x3x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_of_reinterpret
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]