[Mlir-commits] [mlir] [mlir][memref] Rewrite scalar `memref.copy` through reinterpret_cast into load/store (PR #186118)
ioana ghiban
llvmlistbot at llvm.org
Thu Mar 19 03:29:14 PDT 2026
================
@@ -0,0 +1,237 @@
+// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+/// Scalar copy into a 1x1 view of %dst at static offset 0: the copy is
+/// expected to become a load from %src and a store into %dst.
+// CHECK-LABEL: func.func private @concat_zero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_zero_offset(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ /// reinterpret_cast removed
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32> to memref<1x1xf32>
+
+ /// Ensure copy was replaced
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32>
+ /// A negative check only covers the gap before the next positive match, so
+ /// also verify that neither op survives between the store and the return.
+ // CHECK-NOT: memref.reinterpret_cast
+ // CHECK-NOT: memref.copy
+ // CHECK: return
+ return
+}
+
+/// Scalar copy into a 1x1 view of %dst at static offset 1: the store is
+/// expected to address %dst directly, with the offset as the last index.
+// CHECK-LABEL: func.func private @concat_nonzero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_nonzero_offset(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [1], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: 1>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: 1>>
+ /// A negative check only covers the gap before the next positive match, so
+ /// also verify that neither op survives between the store and the return.
+ // CHECK-NOT: memref.reinterpret_cast
+ // CHECK-NOT: memref.copy
+ // CHECK: return
+ return
+}
+
+/// Scalar copy into a 1x1 view of %dst whose offset is an SSA value: the
+/// store is expected to use that value as the last index into %dst.
+// CHECK-LABEL: func.func private @concat_dynamic_offset(
+// CHECK-SAME: %[[OFF:.*]]: index
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_dynamic_offset(%offset: index, %src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [%offset], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: ?>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]]
+ // CHECK-SAME: : memref<1x1xf32>
+ /// Dynamic offset used in store
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[OFF]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: ?>>
+ /// A negative check only covers the gap before the next positive match, so
+ /// also verify that neither op survives between the store and the return.
+ // CHECK-NOT: memref.reinterpret_cast
+ // CHECK-NOT: memref.copy
+ // CHECK: return
+ return
+}
+
+/// Scalar copy into a 1x1 view of %dst with static non-unit strides and
+/// offset 0: the expected store indices are all zero.
+// CHECK-LABEL: func.func private @concat_strided(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_strided(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [107, 2]
+ : memref<1x108xf32> to memref<1x1xf32, strided<[107, 2]>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32, strided<[107, 2]>>
+ /// A negative check only covers the gap before the next positive match, so
+ /// also verify that neither op survives between the store and the return.
+ // CHECK-NOT: memref.reinterpret_cast
+ // CHECK-NOT: memref.copy
+ // CHECK: return
+ return
+}
+
+/// Scalar copy into a 1x1 view of %dst whose strides are SSA values and whose
+/// offset is static 0: the expected store indices are all zero.
+// CHECK-LABEL: func.func private @concat_dynamic_stride(
+// CHECK-SAME: %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xf32>
+func.func private @concat_dynamic_stride(%stride0: index,
+ %stride1: index, %src : memref<1x1xf32>, %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[?, ?]>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ /// Dynamic strides are not used in the store; the static offset 0 is
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[?, ?]>>
+ /// A negative check only covers the gap before the next positive match, so
+ /// also verify that neither op survives between the store and the return.
+ // CHECK-NOT: memref.reinterpret_cast
+ // CHECK-NOT: memref.copy
+ // CHECK: return
+ return
+}
+
+/// Rank-1 variant: scalar copy into a single-element view of %dst at static
+/// offset 0.
+// CHECK-LABEL: func.func private @concat_rank1(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<108xf32>
+func.func private @concat_rank1(%src : memref<1xf32>, %dst : memref<108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1], strides: [1]
+ : memref<108xf32> to memref<1xf32>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0_0]]] : memref<108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1xf32> to memref<1xf32>
+ /// A negative check only covers the gap before the next positive match, so
+ /// also verify that neither op survives between the store and the return.
+ // CHECK-NOT: memref.reinterpret_cast
+ // CHECK-NOT: memref.copy
+ // CHECK: return
+ return
+}
+
+/// Rank-3 variant: scalar copy into a 1x1x1 view of %dst at static offset 0.
+// CHECK-LABEL: func.func private @concat_rank3(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x1x108xf32>
+func.func private @concat_rank3(%src : memref<1x1x1xf32>,
+ %dst : memref<1x1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1]
+ : memref<1x1x108xf32> to memref<1x1x1xf32>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1x1xf32> to memref<1x1x1xf32>
+ /// A negative check only covers the gap before the next positive match, so
+ /// also verify that neither op survives between the store and the return.
+ // CHECK-NOT: memref.reinterpret_cast
+ // CHECK-NOT: memref.copy
+ // CHECK: return
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
----------------
ioghiban wrote:
will do. Should I also remove this comment?
https://github.com/llvm/llvm-project/pull/186118
More information about the Mlir-commits
mailing list