[Mlir-commits] [mlir] [mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store (PR #95223)

Kunwar Grover llvmlistbot at llvm.org
Wed Jun 12 04:12:19 PDT 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/95223

This patch adds patterns to fold memref alias ops (memref.expand_shape/memref.collapse_shape) feeding into vector.load/vector.store and vector.maskedload/vector.maskedstore.

>From c41437beef1b5b6c04e63c6d673922ecd2c6f967 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 11 Jun 2024 18:09:26 +0000
Subject: [PATCH 1/2] [mlir][Memref] Add folders for expand/collapse_shape for
 vector load/store

---
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  | 90 ++++++++++++++++---
 1 file changed, 79 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index db085b386483c..96daf4c5972a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, collapseShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
-                                                  expandShapeOp.getViewSource(),
-                                                  sourceIndices);
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices);
+      })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
       })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
@@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
-            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
             sourceIndices);
       })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
                LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
                LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
                SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
       patterns.getContext());
 }

>From 08bfbbe2ac7e1f449ea778601c0042fee650c74d Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 12 Jun 2024 11:09:53 +0000
Subject: [PATCH 2/2] Add tests

---
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 164 +++++++++++++++++-
 1 file changed, 156 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index e49dff44ae0d6..d67d6df23f90b 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -819,14 +819,14 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 
 // -----
 
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
 }
 
-//      CHECK: func @fold_vector_load
+//      CHECK: func @fold_vector_load_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -834,14 +834,14 @@ func.func @fold_vector_load(
 
 // -----
 
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
   return %1 : vector<32xf32>
 }
 
-//      CHECK: func @fold_vector_maskedload
+//      CHECK: func @fold_vector_maskedload_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(
 
 // -----
 
-func.func @fold_vector_store(
+func.func @fold_vector_store_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
   return
 }
 
-//      CHECK: func @fold_vector_store
+//      CHECK: func @fold_vector_store_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,14 +868,14 @@ func.func @fold_vector_store(
 
 // -----
 
-func.func @fold_vector_maskedstore(
+func.func @fold_vector_maskedstore_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
   return
 }
 
-//      CHECK: func @fold_vector_maskedstore
+//      CHECK: func @fold_vector_maskedstore_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -883,3 +883,151 @@ func.func @fold_vector_maskedstore(
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
 //      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
 //      CHECK:   return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.load %0[%arg1, %c0] : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.load %[[ARG0]][%[[IDX]]]
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.store %val, %0[%arg1, %c0] : memref<4x8xf32>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]]]
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.load %0[%arg1] : memref<32xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]]
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.store %val, %0[%arg1] : memref<32xf32>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]]
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]



More information about the Mlir-commits mailing list