[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Apr 23 01:42:22 PDT 2026
================
@@ -220,3 +221,301 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
: memref<1x1xf32> to memref<1x1xf32>
return
}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+/// For rank-1 MemRefs, expansion/collapsing may be considered on either side.
+
+/// A load through a cast of memref<1xi64> to memref<1x1x1xi64> (offset 0,
+/// unit sizes and strides) is expected to be rewritten as a rank-1 load of the
+/// source at the zero index, with no memref.reinterpret_cast left between the
+/// zero constant and that load.
+// CHECK-LABEL: func.func private @expand_scalar(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
+func.func private @expand_scalar(%src : memref<1xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+ to memref<1x1x1xi64>
+ // The three zero indices of the rank-3 load collapse to a single index.
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xi64>
+ return
+}
+
+/// A load through a rank-reducing cast of memref<1x1x1xi64> to
+/// memref<1x1xi64> is expected to be rewritten as a rank-3 load of the source
+/// whose indices are all zero constants.
+// CHECK-LABEL: func.func private @collapse_scalar(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @collapse_scalar(%src : memref<1x1x1xi64>) {
+ // The input defines a single zero constant, yet two are expected in the
+ // output, so at least one of them must come from the rewrite.
+ // NOTE(review): both captures below match identical text, so which one binds
+ // to which printed constant depends on their output order; the load further
+ // down then pins C0_0 as the leading index. Confirm this ordering dependence
+ // is intended.
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+ to memref<1x1xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C0]]] : memref<1x1x1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0] : memref<1x1xi64>
+ return
+}
+
+/// A load at [0, 0, 0] through a cast of memref<999xi64> to
+/// memref<1x1x999xi64> (offset 0, strides [999, 999, 1]) is expected to become
+/// a rank-1 load of the source at the zero index, with no
+/// memref.reinterpret_cast left between the zero constant and that load.
+// CHECK-LABEL: func.func private @expand_left_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
+func.func private @expand_left_vector(%src : memref<999xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+ : memref<999xi64> to memref<1x1x999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
+ return
+}
+
+/// Same cast as @expand_left_vector, but the innermost load index is the
+/// function argument %i; the rewritten rank-1 load of the source is expected
+/// to use that argument directly as its only index.
+// CHECK-LABEL: func.func private @expand_left_vector_dynamic_index(
+// CHECK-SAME: %[[I:.*]]: index
+// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
+func.func private @expand_left_vector_dynamic_index(%i : index,
+ %src : memref<999xi64>) {
+ // NOTE(review): the zero constant is still expected in the output even
+ // though the rewritten load below no longer uses it; this only holds if the
+ // pass does not erase dead constants -- confirm.
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+ : memref<999xi64> to memref<1x1x999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[I]]] : memref<999xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %i] : memref<1x1x999xi64>
+ return
+}
+
+/// A load at [1] through a cast of memref<1x1x999xi64> to memref<999xi64> is
+/// expected to become a rank-3 load of the source at [0, 0, 1].
+// CHECK-LABEL: func.func private @collapse_left_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x999xi64>) {
+func.func private @collapse_left_vector(%src : memref<1x1x999xi64>) {
+ // The input defines no zero constant, so the zero used for the two leading
+ // indices of the rewritten load has to be created by the rewrite.
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [999], strides: [1]
+ : memref<1x1x999xi64> to memref<999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
+ %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @partial_expand_left_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x999xf32>) {
+func.func private @partial_expand_left_vector(
+ %src : memref<1x999xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+ : memref<1x999xf32> to memref<1x1x999xf32>
----------------
banach-space wrote:
[nit] Here and in other places, the extra indentation is not needed IMHO.
```suggestion
: memref<1x999xf32> to memref<1x1x999xf32>
```
https://github.com/llvm/llvm-project/pull/188459
More information about the Mlir-commits
mailing list