[Mlir-commits] [mlir] [mlir][memref] Add memref alias folding for masked transfers (PR #71476)
Quinn Dawkins
llvmlistbot at llvm.org
Tue Nov 7 05:40:15 PST 2023
https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/71476
>From 89cbc80242c2598c610d7e6632a4cadeb7e97e10 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 6 Nov 2023 09:03:48 -0500
Subject: [PATCH 1/2] [MLIR][MemRef] Add memref alias folding for masked
transfers
Because the mask of a vector.transfer op semantically applies to the
unpermuted input vector (for reads) and the permuted output vector (for
writes), it applies independently of any subviews.
---
.../MemRef/Transforms/FoldMemRefAliasOps.cpp | 6 +-
.../Dialect/MemRef/fold-memref-alias-ops.mlir | 57 +++++++++++++++++++
2 files changed, 59 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 043e8fbcdd2f6fb..b78c4510ff88585 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
"must be a vector transfer op");
if (xferOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
- if (xferOp.getMask())
- return rewriter.notifyMatchFailure(xferOp, "masked transfer");
if (!subviewOp.hasUnitStride()) {
return rewriter.notifyMatchFailure(
xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
- op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
+ op.getPadding(), op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
- op.getInBoundsAttr());
+ op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f11e22749bb16d..2e1319420fb3eaf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -266,6 +266,63 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
// -----
+func.func @fold_masked_vector_transfer_read_with_subview(
+ %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+ %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+ %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
+ : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
+ : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
+ return %1 : vector<4xf32>
+}
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_masked_vector_transfer_read_with_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[ARG7]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_subview(
+ %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+ %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+ : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
+ : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_masked_vector_transfer_write_with_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
// Test with affine.load/store ops. We only do a basic test here since the
// logic is identical to that with memref.load/store ops. The same affine.apply
// ops would be generated.
>From b421c2627fd5182de08dd3f6c5191ac04ee03c1e Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 7 Nov 2023 08:39:55 -0500
Subject: [PATCH 2/2] Add rank reducing masked transfer tests
---
.../Dialect/MemRef/fold-memref-alias-ops.mlir | 72 +++++++++++++++++--
1 file changed, 68 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 2e1319420fb3eaf..8fe87bc8c57c300 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -287,10 +287,42 @@ func.func @fold_masked_vector_transfer_read_with_subview(
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
-// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[ARG7]] {{.*}} : memref<?x?xf32
+// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[MASK]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_read_with_rank_reducing_subview(
+ %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+ %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+ %arg6 : index, %mask : vector<4x3xi1>) -> vector<3x4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[0, %arg1, 0, %arg2] [1, %arg3, 1, %arg4] [1, 1, 1, 1]
+ : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {
+ permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
+ : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<3x4xf32>
+ return %1 : vector<3x4xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
+// CHECK: func @fold_masked_vector_transfer_read_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG5]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[PAD]], %[[MASK]] {{.*}} permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32
// -----
@@ -316,10 +348,42 @@ func.func @fold_masked_vector_transfer_write_with_subview(
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: vector<4xi1>
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
-// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview(
+ %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+ %arg1 : vector<3x4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+ %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4x3xi1>) {
+ %cst = arith.constant 0.0 : f32
+ %0 = memref.subview %arg0[0, %arg2, 0, %arg3] [1, %arg4, 1, %arg5] [1, 1, 1, 1]
+ : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+ memref<?x?xf32, strided<[?, ?], offset: ?>>
+ vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {
+ permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
+ : vector<3x4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
+// CHECK: func @fold_masked_vector_transfer_write_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<3x4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]]]
+// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[MAP1]]} : vector<3x4xf32>, memref<?x?x?x?xf32
// -----
More information about the Mlir-commits
mailing list