[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Mar 26 09:18:08 PDT 2026
================
@@ -220,3 +221,309 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
: memref<1x1xf32> to memref<1x1xf32>
return
}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @reshape_expand_scalar(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
+func.func private @reshape_expand_scalar(%src : memref<1xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+ to memref<1x1x1xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]]] : memref<1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_scalar(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @reshape_collapse_scalar(%src : memref<1x1x1xi64>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+ to memref<1x1xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+ : memref<999xi64> to memref<1x1x999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x999xi64>) {
+func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [999], strides: [1]
+ : memref<1x1x999xi64> to memref<999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
+ %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @reshape_expand_left_inner_unit_dims(
+ %src : memref<1x108xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+ : memref<1x108xf32> to memref<1x1x1x108xf32>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<1x108xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+ : memref<1x1x1x108xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @reshape_collapse_left_inner_unit_dims(
+ %src : memref<1x1x1x100xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 100], strides: [100, 1]
+ : memref<1x1x1x100xf32> to memref<1x100xf32>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1x100xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x100xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
+ : memref<999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<999x1x1xi64,
+ strided<[1, 999, 999]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<999x1x1xi64>) {
+func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [999], strides: [1]
+ : memref<999x1x1xi64> to memref<999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0]]] : memref<999x1x1xi64>
+ %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @reshape_expand_right_inner_unit_dims(
+ %src : memref<108x1xf32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+ : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<108x1xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+ : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<100x1x1x1xf32>) {
+func.func private @reshape_collapse_right_inner_unit_dims(
+ %src : memref<100x1x1x1xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [100, 1], strides: [1, 100]
+ : memref<100x1x1x1xf32> to memref<100x1xf32, strided<[1, 100]>>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]], %[[C0_0]]] : memref<100x1x1x1xf32>
+ %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<100x1xf32,
+ strided<[1, 100]>>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @negative_reshape_nonzero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_reshape_nonzero_offset(
+ %src : memref<1xi64>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
----------------
banach-space wrote:
It suffices to check that `memref.reinterpret_cast` is still here and that it is what `memref.load` loads from. The same comment applies to the other tests.
```suggestion
// CHECK: %[[RC:.*]] = memref.reinterpret_cast
```
https://github.com/llvm/llvm-project/pull/188459
More information about the Mlir-commits
mailing list