[Mlir-commits] [mlir] 48f980c - [mlir][memref] Add memref alias folding for masked transfers (#71476)
llvmlistbot at llvm.org
Tue Nov 7 05:56:59 PST 2023
Author: Quinn Dawkins
Date: 2023-11-07T08:56:54-05:00
New Revision: 48f980c535ca9a4e2c00c56734e6d3346f4b0a86
URL: https://github.com/llvm/llvm-project/commit/48f980c535ca9a4e2c00c56734e6d3346f4b0a86
DIFF: https://github.com/llvm/llvm-project/commit/48f980c535ca9a4e2c00c56734e6d3346f4b0a86.diff
LOG: [mlir][memref] Add memref alias folding for masked transfers (#71476)
The contents of the mask on a masked transfer are unaffected by the
particular region of memory being read from or stored to, so the mask can
simply be forwarded in the subview folding patterns.
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 043e8fbcdd2f6fb..b78c4510ff88585 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
"must be a vector transfer op");
if (xferOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
- if (xferOp.getMask())
- return rewriter.notifyMatchFailure(xferOp, "masked transfer");
if (!subviewOp.hasUnitStride()) {
return rewriter.notifyMatchFailure(
xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
- op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
+ op.getPadding(), op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
- op.getInBoundsAttr());
+ op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f11e22749bb16d..8fe87bc8c57c300 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -266,6 +266,127 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
// -----
+func.func @fold_masked_vector_transfer_read_with_subview(
+ %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+ %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+ %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
+ : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
+ : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
+ return %1 : vector<4xf32>
+}
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_masked_vector_transfer_read_with_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[MASK]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_read_with_rank_reducing_subview(
+ %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+ %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+ %arg6 : index, %mask : vector<4x3xi1>) -> vector<3x4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[0, %arg1, 0, %arg2] [1, %arg3, 1, %arg4] [1, 1, 1, 1]
+ : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {
+ permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
+ : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<3x4xf32>
+ return %1 : vector<3x4xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
+// CHECK: func @fold_masked_vector_transfer_read_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG5]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[PAD]], %[[MASK]] {{.*}} permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_subview(
+ %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+ %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+ : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
+ : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_masked_vector_transfer_write_with_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview(
+ %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+ %arg1 : vector<3x4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4x3xi1>) {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[0, %arg2, 0, %arg3] [1, %arg4, 1, %arg5] [1, 1, 1, 1]
+ : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {
+ permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
+ : vector<3x4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
+// CHECK: func @fold_masked_vector_transfer_write_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<3x4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]]]
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[MAP1]]} : vector<3x4xf32>, memref<?x?x?x?xf32
+
+// -----
+
// Test with affine.load/store ops. We only do a basic test here since the
// logic is identical to that with memref.load/store ops. The same affine.apply
// ops would be generated.