[Mlir-commits] [mlir] dae3c44 - [mlir] Add memref alias folding for `vector.store`/`vector.maskedstore` of `memref.subview` (#72184)

llvmlistbot at llvm.org
Tue Nov 14 14:24:58 PST 2023


Author: Max191
Date: 2023-11-14T14:24:54-08:00
New Revision: dae3c44ce6736aec63b167a6c1da10892584bc75

URL: https://github.com/llvm/llvm-project/commit/dae3c44ce6736aec63b167a6c1da10892584bc75
DIFF: https://github.com/llvm/llvm-project/commit/dae3c44ce6736aec63b167a6c1da10892584bc75.diff

LOG: [mlir] Add memref alias folding for `vector.store`/`vector.maskedstore` of `memref.subview` (#72184)

Fixes https://github.com/openxla/iree/issues/15575

This extends the memref alias-folding patterns so that a `vector.store` or `vector.maskedstore` whose base is a `memref.subview` is rewritten to operate directly on the subview's source memref, with the store indices resolved through the subview's offsets, mirroring the existing `vector.load`/`vector.maskedload` support.
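For context, the patterns are consumed through `memref::populateFoldMemRefAliasOpPatterns`, which the diff below extends. The following is a minimal sketch (not part of this commit) of a C++ pass that drives those patterns with the greedy rewriter; the pass name and the header path for the populate function are assumptions here, and upstream already ships an equivalent fold-memref-alias-ops pass:

    // Hedged sketch: a hypothetical pass that applies the memref alias-folding
    // patterns, including the vector.store/maskedstore cases added below.
    #include "mlir/Dialect/MemRef/Transforms/Transforms.h" // assumed location of populateFoldMemRefAliasOpPatterns
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Pass/Pass.h"
    #include "mlir/Support/TypeID.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    namespace {
    struct FoldMemRefAliasOpsSketchPass
        : public PassWrapper<FoldMemRefAliasOpsSketchPass, OperationPass<>> {
      MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldMemRefAliasOpsSketchPass)

      void runOnOperation() override {
        RewritePatternSet patterns(&getContext());
        // Registers all alias-folding patterns, which after this change also
        // include StoreOpOfSubViewOpFolder<vector::StoreOp> and
        // StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>.
        memref::populateFoldMemRefAliasOpPatterns(patterns);
        if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                                std::move(patterns))))
          signalPassFailure();
      }
    };
    } // namespace

The updated fold-memref-alias-ops.mlir test below presumably exercises the same pattern set through the existing upstream pass rather than a custom one.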

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index b78c4510ff88585..aa44455ada7f9aa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -187,8 +187,12 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
 
 static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
 
+static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
+
 static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
 
+static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
+
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.getSource();
 }
@@ -557,6 +561,15 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
                 subViewOp.getDroppedDims())),
             op.getMask(), op.getInBoundsAttr());
       })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, subViewOp.getSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
+      })
       .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
             op, op.getSrc(), subViewOp.getSource(), sourceIndices,
@@ -698,6 +711,8 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
                StoreOpOfSubViewOpFolder<memref::StoreOp>,
                StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
+               StoreOpOfSubViewOpFolder<vector::StoreOp>,
+               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
                LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,

diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 8fe87bc8c57c300..96b72e042b9e0d6 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -803,3 +803,38 @@ func.func @fold_vector_maskedload(
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
 //      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+
+// -----
+
+func.func @fold_vector_store(
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
+  return
+}
+
+//      CHECK: func @fold_vector_store
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
+//      CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<2x32xf32>
+//      CHECK:   return
+
+// -----
+
+func.func @fold_vector_maskedstore(
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
+  return
+}
+
+//      CHECK: func @fold_vector_maskedstore
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
+//      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
+//      CHECK:   return
