[Mlir-commits] [mlir] [mlir][memref] Add memref alias folding for masked transfers (PR #71476)

llvmlistbot at llvm.org
Mon Nov 6 18:00:57 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Changes:

Because masks on vector.transfer ops apply to the unpermuted input vector (for reads) and the permuted output vector (for writes), they are independent of any subviews and can thus be forwarded unchanged to the folded transfers.
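
For illustration, here is a minimal before/after sketch of the fold (hypothetical value names; it mirrors the new test cases in the diff below):

```mlir
// Before: masked transfer_read through a subview (hypothetical values).
%sv = memref.subview %src[%i, %j] [%m, %n] [1, 1]
    : memref<?x?xf32, strided<[?, ?], offset: ?>> to
      memref<?x?xf32, strided<[?, ?], offset: ?>>
%v = vector.transfer_read %sv[%k, %l], %pad, %mask {in_bounds = [true]}
    : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>

// After: the subview is folded away; the mask %mask is forwarded as-is,
// and only the indices are rewritten to %i + %k and %j + %l.
%idx0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %k]
%idx1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%j, %l]
%r = vector.transfer_read %src[%idx0, %idx1], %pad, %mask {in_bounds = [true]}
    : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
```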

---
Full diff: https://github.com/llvm/llvm-project/pull/71476.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+2-4) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+57) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 043e8fbcdd2f6fb..b78c4510ff88585 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
       "must be a vector transfer op");
   if (xferOp.hasOutOfBoundsDim())
     return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
-  if (xferOp.getMask())
-    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
   if (!subviewOp.hasUnitStride()) {
     return rewriter.notifyMatchFailure(
         xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
+            op.getPadding(), op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getInBoundsAttr());
+            op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f11e22749bb16d..2e1319420fb3eaf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -266,6 +266,63 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
 
 // -----
 
+func.func @fold_masked_vector_transfer_read_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+    %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
+      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_read_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//       CHECK:    vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[ARG7]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
+      : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_write_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG8:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
 //  Test with affine.load/store ops. We only do a basic test here since the
 //  logic is identical to that with memref.load/store ops. The same affine.apply
 //  ops would be generated.

``````````



https://github.com/llvm/llvm-project/pull/71476

