[Mlir-commits] [mlir] [mlir][memref] Add memref alias folding for masked transfers (PR #71476)

Quinn Dawkins llvmlistbot at llvm.org
Mon Nov 6 18:00:28 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71476

Because the masks of vector.transfer ops apply to the unpermuted input vector (for reads) and the permuted output vector (for writes), they apply independently of any subviews and can thus be forwarded to the folded transfers.

>From 89cbc80242c2598c610d7e6632a4cadeb7e97e10 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 6 Nov 2023 09:03:48 -0500
Subject: [PATCH] [MLIR][MemRef] Add memref alias folding for masked transfers

Because the masks of vector.transfer ops semantically apply to the
unpermuted input vector (for reads) and the permuted output vector (for
writes), they apply independently of any subviews.
---
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  |  6 +-
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 57 +++++++++++++++++++
 2 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 043e8fbcdd2f6fb..b78c4510ff88585 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
       "must be a vector transfer op");
   if (xferOp.hasOutOfBoundsDim())
     return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
-  if (xferOp.getMask())
-    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
   if (!subviewOp.hasUnitStride()) {
     return rewriter.notifyMatchFailure(
         xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
+            op.getPadding(), op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getInBoundsAttr());
+            op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f11e22749bb16d..2e1319420fb3eaf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -266,6 +266,63 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
 
 // -----
 
+func.func @fold_masked_vector_transfer_read_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+    %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
+      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_read_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//       CHECK:    vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[ARG7]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
+      : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_write_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG8:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
 //  Test with affine.load/store ops. We only do a basic test here since the
 //  logic is identical to that with memref.load/store ops. The same affine.apply
 //  ops would be generated.



More information about the Mlir-commits mailing list