[Mlir-commits] [mlir] 48f980c - [mlir][memref] Add memref alias folding for masked transfers (#71476)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 7 05:56:59 PST 2023


Author: Quinn Dawkins
Date: 2023-11-07T08:56:54-05:00
New Revision: 48f980c535ca9a4e2c00c56734e6d3346f4b0a86

URL: https://github.com/llvm/llvm-project/commit/48f980c535ca9a4e2c00c56734e6d3346f4b0a86
DIFF: https://github.com/llvm/llvm-project/commit/48f980c535ca9a4e2c00c56734e6d3346f4b0a86.diff

LOG: [mlir][memref] Add memref alias folding for masked transfers (#71476)

The contents of the mask on a masked transfer are unaffected by the
particular region of memory being read from or written to, so the mask
can simply be forwarded by the subview folding patterns.
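As a rough illustration (the SSA names %src, %mask, %pad and the index
operands below are placeholders, not taken from the patch), a masked read
through a unit-stride subview now folds along the lines of the first test
case added here:

  %sv = memref.subview %src[%i, %j] [%m, %n] [1, 1]
      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  %v = vector.transfer_read %sv[%k, %l], %pad, %mask {in_bounds = [true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>

becoming a read directly from %src at offset-adjusted indices, with the
mask passed through unchanged:

  %idx0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %k]
  %idx1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%j, %l]
  %v = vector.transfer_read %src[%idx0, %idx1], %pad, %mask {in_bounds = [true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>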

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 043e8fbcdd2f6fb..b78c4510ff88585 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
       "must be a vector transfer op");
   if (xferOp.hasOutOfBoundsDim())
     return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
-  if (xferOp.getMask())
-    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
   if (!subviewOp.hasUnitStride()) {
     return rewriter.notifyMatchFailure(
         xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
+            op.getPadding(), op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getInBoundsAttr());
+            op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(

diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f11e22749bb16d..8fe87bc8c57c300 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -266,6 +266,127 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
 
 // -----
 
+func.func @fold_masked_vector_transfer_read_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+    %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
+      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_read_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//       CHECK:    vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[MASK]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_read_with_rank_reducing_subview(
+    %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+    %arg6 : index, %mask : vector<4x3xi1>) -> vector<3x4xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[0, %arg1, 0, %arg2] [1, %arg3, 1, %arg4] [1, 1, 1, 1]
+      : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {
+         permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
+      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<3x4xf32>
+  return %1 : vector<3x4xf32>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
+//       CHECK: func @fold_masked_vector_transfer_read_with_rank_reducing_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
+//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:    %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG5]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
+//       CHECK:    vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[PAD]], %[[MASK]] {{.*}} permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
+      : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_write_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview(
+    %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+    %arg1 : vector<3x4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4x3xi1>) {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[0, %arg2, 0, %arg3] [1, %arg4, 1, %arg5] [1, 1, 1, 1]
+      : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {
+        permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
+      : vector<3x4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
+//       CHECK: func @fold_masked_vector_transfer_write_with_rank_reducing_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<3x4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
+//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]]]
+//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[MAP1]]} : vector<3x4xf32>, memref<?x?x?x?xf32
+
+// -----
+
 //  Test with affine.load/store ops. We only do a basic test here since the
 //  logic is identical to that with memref.load/store ops. The same affine.apply
 //  ops would be generated.

More information about the Mlir-commits mailing list