[Mlir-commits] [mlir] [mlir][memref] Rewrite scalar `memref.copy` through reinterpret_cast into load/store (PR #186118)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Mar 19 03:59:15 PDT 2026


================
@@ -0,0 +1,237 @@
+// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @concat_zero_offset(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+/// Single-element copy into a zero-offset, unit-stride 1x1 view of %dst:
+/// expected to become a load from %src[0, 0] and a store to %dst[0, 0].
+func.func private @concat_zero_offset(%src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  /// reinterpret_cast removed
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1], strides: [1, 1]
+    : memref<1x108xf32> to memref<1x1xf32>
+
+  /// Ensure copy was replaced
+  // CHECK-NOT:  memref.copy
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+  /// A copy that was not erased would presumably remain after the new
+  /// load/store pair, so also make sure none follows the store.
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32> to memref<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_nonzero_offset(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+/// Static linear offset 1 into a 1x108 row-major buffer: the store is
+/// expected to land on %dst[0, 1].
+func.func private @concat_nonzero_offset(%src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [1], sizes: [1, 1], strides: [1, 1]
+    : memref<1x108xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: 1>>
+
+  // CHECK-NOT:  memref.copy
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C1:.*]] = arith.constant 1 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+  /// A surviving copy would presumably sit after the store; reject it there.
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: 1>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_dynamic_offset(
+// CHECK-SAME:   %[[OFF:.*]]: index
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+/// Dynamic linear offset: the offset SSA value itself is expected to be used
+/// as the innermost store index (no constant is materialised for it).
+func.func private @concat_dynamic_offset(%offset: index, %src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [%offset], sizes: [1, 1], strides: [1, 1]
+    : memref<1x108xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: ?>>
+
+  // CHECK-NOT:  memref.copy
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: : memref<1x1xf32>
+  /// Dynamic offset used in store
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[OFF]]] : memref<1x108xf32>
+  /// A surviving copy would presumably sit after the store; reject it there.
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_strided(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+/// Non-unit static strides on a 1x1 view: with every size equal to 1 only
+/// element [0, 0] is touched, so the store is expected at %dst[0, 0].
+func.func private @concat_strided(%src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1], strides: [107, 2]
+    : memref<1x108xf32> to memref<1x1xf32, strided<[107, 2]>>
+
+  // CHECK-NOT:  memref.copy
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+  /// A surviving copy would presumably sit after the store; reject it there.
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32> to memref<1x1xf32, strided<[107, 2]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_dynamic_stride(
+// CHECK-SAME:   %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xf32>
+/// Dynamic strides with a static zero offset on a 1x1 view: the strides never
+/// scale a non-zero index, so the store is expected at %dst[0, 0].
+func.func private @concat_dynamic_stride(%stride0: index,
+  %stride1: index, %src : memref<1x1xf32>, %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+    : memref<1x108xf32>
+      to memref<1x1xf32, strided<[?, ?]>>
+
+  // CHECK-NOT:  memref.copy
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  /// Dynamic strides do not appear in the store: the static zero offset does
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+  /// A surviving copy would presumably sit after the store; reject it there.
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32>
+      to memref<1x1xf32, strided<[?, ?]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_rank1(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<108xf32>
+/// Rank-1 variant: the zero offset becomes the single store index.
+func.func private @concat_rank1(%src : memref<1xf32>, %dst : memref<108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1], strides: [1]
+    : memref<108xf32> to memref<1xf32>
+
+  // CHECK-NOT:  memref.copy
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0_0]]] : memref<108xf32>
+  /// A surviving copy would presumably sit after the store; reject it there.
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1xf32> to memref<1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_rank3(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x1x108xf32>
+/// Rank-3 variant: the leading store indices are zero and the innermost index
+/// carries the (zero) offset.
+func.func private @concat_rank3(%src : memref<1x1x1xf32>,
+  %dst : memref<1x1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1]
+    : memref<1x1x108xf32> to memref<1x1x1xf32>
+
+  // CHECK-NOT:  memref.copy
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x108xf32>
+  /// A surviving copy would presumably sit after the store; reject it there.
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1x1xf32> to memref<1x1x1xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
----------------
banach-space wrote:

You can leave the comment, I find it quite helpful.

https://github.com/llvm/llvm-project/pull/186118


More information about the Mlir-commits mailing list