[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Mar 26 09:18:08 PDT 2026


================
@@ -220,3 +221,309 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
   : memref<1x1xf32> to memref<1x1xf32>
   return
 }
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @reshape_expand_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @reshape_expand_scalar(%src : memref<1xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]]] : memref<1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @reshape_collapse_scalar(%src : memref<1x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+      to memref<1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<999xi64> to memref<1x1x999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x999xi64>) {
+func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<1x1x999xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
+  %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @reshape_expand_left_inner_unit_dims(
+    %src : memref<1x108xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+      : memref<1x108xf32> to memref<1x1x1x108xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<1x108xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+    : memref<1x1x1x108xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @reshape_collapse_left_inner_unit_dims(
+    %src : memref<1x1x1x100xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 100], strides: [100, 1]
+      : memref<1x1x1x100xf32> to memref<1x100xf32>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1x100xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x100xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
+      : memref<999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<999x1x1xi64,
+    strided<[1, 999, 999]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999x1x1xi64>) {
+func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<999x1x1xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0]]] : memref<999x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @reshape_expand_right_inner_unit_dims(
+    %src : memref<108x1xf32>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+      : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<108x1xf32>
+  %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+    : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_inner_unit_dims(
+// CHECK-SAME:    %[[SRC:.*]]: memref<100x1x1x1xf32>) {
+func.func private @reshape_collapse_right_inner_unit_dims(
+    %src : memref<100x1x1x1xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [100, 1], strides: [1, 100]
+      : memref<100x1x1x1xf32> to memref<100x1xf32, strided<[1, 100]>>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]], %[[C0_0]]] : memref<100x1x1x1xf32>
+  %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<100x1xf32,
+    strided<[1, 100]>>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @negative_reshape_nonzero_offset(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_reshape_nonzero_offset(
+    %src : memref<1xi64>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
----------------
banach-space wrote:

It suffices to check that `memref.reinterpret_cast` is still present and that it is what `memref.load` loads from. The same comment applies to the other tests.

```suggestion
  // CHECK:       %[[RC:.*]] = memref.reinterpret_cast
```

https://github.com/llvm/llvm-project/pull/188459


More information about the Mlir-commits mailing list