[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 23 01:42:22 PDT 2026


================
@@ -220,3 +221,301 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
   : memref<1x1xf32> to memref<1x1xf32>
   return
 }
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+/// For rank-1 MemRefs, expansion/collapsing may be considered on either side.
+
+// Expand case: the reinterpret_cast turns memref<1xi64> into
+// memref<1x1x1xi64> at offset 0 with unit strides. The expected output drops
+// the cast and loads from %src using a single index.
+// CHECK-LABEL: func.func private @expand_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1xi64>) {
+func.func private @expand_scalar(%src : memref<1xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+      to memref<1x1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xi64>
+  return
+}
+
+// Collapse case: the reinterpret_cast turns memref<1x1x1xi64> into
+// memref<1x1xi64>. The expected output drops the cast and indexes %src
+// directly with three zero indices. Those indices come from two distinct zero
+// constants in the output, matched in either order. Presumably these are the
+// original %c0 plus one created by the rewrite -- TODO confirm.
+// CHECK-LABEL: func.func private @collapse_scalar(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @collapse_scalar(%src : memref<1x1x1xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C0_0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+      to memref<1x1xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C0]]] : memref<1x1x1xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0] : memref<1x1xi64>
+  return
+}
+
+// Expand case with a non-unit trailing dim: memref<999xi64> is reinterpreted
+// as memref<1x1x999xi64> (offset 0, strides [999, 999, 1]). The expected load
+// on %src keeps only the innermost index. The two leading unit-dim indices
+// are dropped.
+// CHECK-LABEL: func.func private @expand_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @expand_left_vector(%src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<999xi64> to memref<1x1x999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
+  return
+}
+
+// Same reinterpret_cast as in @expand_left_vector, but the innermost load
+// index is the dynamic function argument %i rather than a constant. The
+// expected load on %src uses %i directly as its only index.
+// CHECK-LABEL: func.func private @expand_left_vector_dynamic_index(
+// CHECK-SAME:    %[[I:.*]]: index
+// CHECK-SAME:    %[[SRC:.*]]: memref<999xi64>) {
+func.func private @expand_left_vector_dynamic_index(%i : index,
+    %src : memref<999xi64>) {
+  // CHECK:       %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<999xi64> to memref<1x1x999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[I]]] : memref<999xi64>
+  %0 = memref.load %reinterpret_cast[%c0, %c0, %i] : memref<1x1x999xi64>
+  return
+}
+
+// Collapse case: memref<1x1x999xi64> is reinterpreted as memref<999xi64>.
+// The expected output maps the single load index %c1 back onto %src as
+// [0, 0, 1]. This requires a zero index constant in the output, even though
+// the input only defines %c1.
+// CHECK-LABEL: func.func private @collapse_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x1x999xi64>) {
+func.func private @collapse_left_vector(%src : memref<1x1x999xi64>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [999], strides: [1]
+      : memref<1x1x999xi64> to memref<999xi64>
+  // CHECK:       %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
+  %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+  return
+}
+
+// CHECK-LABEL: func.func private @partial_expand_left_vector(
+// CHECK-SAME:    %[[SRC:.*]]: memref<1x999xf32>) {
+func.func private @partial_expand_left_vector(
+    %src : memref<1x999xf32>) {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT:   memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %src
+    to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+      : memref<1x999xf32> to memref<1x1x999xf32>
----------------
banach-space wrote:

[nit] Here and in other places, the extra indentation is not needed IMHO.
```suggestion
    : memref<1x999xf32> to memref<1x1x999xf32>
```

https://github.com/llvm/llvm-project/pull/188459


More information about the Mlir-commits mailing list