[Mlir-commits] [mlir] [mlir] Implement memory-space cast operand fusion into consumers (PR #159454)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Sep 19 03:28:16 PDT 2025
================
@@ -0,0 +1,281 @@
+// RUN: mlir-opt %s --fuse-memory-space-casts-into-consumers | FileCheck %s
+
+#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+// CHECK-LABEL: func.func @load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref<?xf32, 1>
+// CHECK: return
+// CHECK: }
+func.func @load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = memref.load %memspacecast[%arg1] : memref<?xf32>
+ memref.store %0, %memspacecast[%arg1] : memref<?xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @load_store_unfoldable(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref<?xf32, 1> to memref<?xf32, 2>
+// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK: memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref<?xf32, 2>
+// CHECK: return
+// CHECK: }
+func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32, 2>
+ %0 = memref.load %memspacecast[%arg1] : memref<?xf32, 2>
+ memref.store %0, %memspacecast[%arg1] : memref<?xf32, 2>
+ return
+}
+
+// CHECK-LABEL: func.func @view(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref<?xi8, 1> to memref<?x?xi8, 1>
+// CHECK: %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref<?x?xi8, 1> to memref<?x?xi8>
+// CHECK: return %[[VAL_2]] : memref<?x?xi8>
+// CHECK: }
+func.func @view(%arg0: memref<?xi8, 1>, %arg1: index, %arg2: index) -> memref<?x?xi8> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xi8, 1> to memref<?xi8>
+ %c100 = arith.constant 100 : index
+ %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref<?xi8> to memref<?x?xi8>
+ return %view : memref<?x?xi8>
+}
+
+// CHECK-LABEL: func.func @subview(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+// CHECK: %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref<?x?xf32, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK: return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>>
+// CHECK: }
+func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref<?x?xf32> to memref<8x2xf32, strided<[?, 2], offset: ?>>
+ return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>>
+}
+
+// CHECK-LABEL: func.func @reinterpret_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
+// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: }
+func.func @reinterpret_cast(%arg0: memref<?xf32, 1>, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
+ return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>>
+}
+
+// CHECK-LABEL: func.func @reshape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<1xindex>) -> memref<?xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref<?x?xf32, 1>, memref<1xindex>) -> memref<?xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK: return %[[VAL_1]] : memref<?xf32>
+// CHECK: }
+func.func @reshape(%arg0: memref<?x?xf32, 1>, %arg1: memref<1xindex>) -> memref<?xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %reshape = memref.reshape %memspacecast(%arg1) : (memref<?x?xf32>, memref<1xindex>) -> memref<?xf32>
+ return %reshape : memref<?xf32>
+}
+
+// CHECK-LABEL: func.func @expand_shape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32>
+// CHECK: return %[[VAL_1]] : memref<3x4xf32>
+// CHECK: }
+func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32>
+ %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32>
+ return %expand_shape : memref<3x4xf32>
+}
+
+// CHECK-LABEL: func.func @collapse_shape(
+// CHECK-SAME: %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32>
+// CHECK: return %[[VAL_1]] : memref<12xf32>
+// CHECK: }
+func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32>
+ %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32>
+ return %collapse_shape : memref<12xf32>
+}
+
+// CHECK-LABEL: func.func @transpose(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, 1>) -> memref<?x?xf32, #[[$ATTR_0]]> {
+// CHECK: %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref<?x?xf32, 1> to memref<?x?xf32, #[[$ATTR_0]], 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x?xf32, #[[$ATTR_0]], 1> to memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK: return %[[VAL_1]] : memref<?x?xf32, #[[$ATTR_0]]>
+// CHECK: }
+func.func @transpose(%arg0: memref<?x?xf32, 1>) -> memref<?x?xf32, #map> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?x?xf32, 1> to memref<?x?xf32>
+ %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #map>
+ return %transpose : memref<?x?xf32, #map>
+}
+
+// CHECK-LABEL: func.func @atomic_rmw(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref<?xf32, 1>) -> f32
+// CHECK: return %[[VAL_0]] : f32
+// CHECK: }
+func.func @atomic_rmw(%arg0: memref<?xf32, 1>, %arg1: index, %arg2: f32) -> f32 {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %0 = memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref<?xf32>) -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @assume_alignment(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>) -> memref<?xf32> {
+// CHECK: %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref<?xf32, 1>
+// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?xf32, 1> to memref<?xf32>
+// CHECK: return %[[VAL_1]] : memref<?xf32>
+// CHECK: }
+func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
+ %memspacecast = memref.memory_space_cast %arg0 : memref<?xf32, 1> to memref<?xf32>
+ %1 = memref.assume_alignment %memspacecast, 16 : memref<?xf32>
+ return %1 : memref<?xf32>
+}
+
+// CHECK-LABEL: func.func @op_with_cast_sequence(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf32, 1>,
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: f32) -> memref<16xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+// CHECK: %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
----------------
nicolasvasilache wrote:
the position of this op feels weird, haven't checked the impl yet but you may want to use CHECK-DAG to avoid surprises
https://github.com/llvm/llvm-project/pull/159454
More information about the Mlir-commits
mailing list