[Mlir-commits] [mlir] 57e4360 - [mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store (#95223)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 12 07:36:21 PDT 2024


Author: Kunwar Grover
Date: 2024-06-12T15:36:16+01:00
New Revision: 57e4360836f421a2c6131de51e3845620c6aea76

URL: https://github.com/llvm/llvm-project/commit/57e4360836f421a2c6131de51e3845620c6aea76
DIFF: https://github.com/llvm/llvm-project/commit/57e4360836f421a2c6131de51e3845620c6aea76.diff

LOG: [mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store (#95223)

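This patch adds patterns to fold memref aliases for
expand_shape/collapse_shape feeding into vector.load/vector.store and
vector.maskedload/vector.maskedstore. It also preserves the `nontemporal`
attribute when folding memref.load/memref.store through
expand_shape/collapse_shape.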

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index db085b386483c..96daf4c5972a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, collapseShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
-                                                  expandShapeOp.getViewSource(),
-                                                  sourceIndices);
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices);
+      })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
       })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
@@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
-            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
             sourceIndices);
       })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
                LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
                LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
                SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
       patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index e49dff44ae0d6..327cacf7d9a20 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -473,10 +473,10 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -487,11 +487,11 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   %c0 = arith.constant 0 : index
   %c1f32 = arith.constant 1.0 : f32
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
 // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
@@ -819,14 +819,14 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 
 // -----
 
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
 }
 
-//      CHECK: func @fold_vector_load
+//      CHECK: func @fold_vector_load_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -834,14 +834,14 @@ func.func @fold_vector_load(
 
 // -----
 
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
   return %1 : vector<32xf32>
 }
 
-//      CHECK: func @fold_vector_maskedload
+//      CHECK: func @fold_vector_maskedload_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(
 
 // -----
 
-func.func @fold_vector_store(
+func.func @fold_vector_store_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
   return
 }
 
-//      CHECK: func @fold_vector_store
+//      CHECK: func @fold_vector_store_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,14 +868,14 @@ func.func @fold_vector_store(
 
 // -----
 
-func.func @fold_vector_maskedstore(
+func.func @fold_vector_maskedstore_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
   return
 }
 
-//      CHECK: func @fold_vector_maskedstore
+//      CHECK: func @fold_vector_maskedstore_subview
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -883,3 +883,151 @@ func.func @fold_vector_maskedstore(
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
 //      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
 //      CHECK:   return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
+//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]


        


More information about the Mlir-commits mailing list