[Mlir-commits] [mlir] [mlir] Add `vector.store` of `memref.subview` memref alias folding (PR #72184)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 13 17:12:59 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref
Author: None (Max191)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/72184.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+7)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+17)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index b78c4510ff88585..85878aff2701ffe 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
+static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
+
static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
static Value getMemRefOperand(vector::TransferWriteOp op) {
@@ -557,6 +559,10 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
subViewOp.getDroppedDims())),
op.getMask(), op.getInBoundsAttr());
})
+ .Case([&](vector::StoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
+ })
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
op, op.getSrc(), subViewOp.getSource(), sourceIndices,
@@ -698,6 +704,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
StoreOpOfSubViewOpFolder<memref::StoreOp>,
StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
+ StoreOpOfSubViewOpFolder<vector::StoreOp>,
StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 8fe87bc8c57c300..85bc51fcf343d2d 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -803,3 +803,20 @@ func.func @fold_vector_maskedload(
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+
+// -----
+
+func.func @fold_vector_store(
+ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
+ %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
+ return
+}
+
+// CHECK: func @fold_vector_store
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
+// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<2x32xf32>
+// CHECK: return
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/72184
More information about the Mlir-commits mailing list