[Mlir-commits] [mlir] [MLIR][Vector] Add support for distributing masked writes (PR #71482)
Quinn Dawkins
llvmlistbot at llvm.org
Mon Nov 6 19:54:15 PST 2023
https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71482
Because the mask applies to the un-permuted write vector, we can simply distribute the mask identically to the vector, if present.
From 0a720ee8347c743ea95c4bc9e5737f7f0e71efc9 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 10:22:28 -0500
Subject: [PATCH] [MLIR][Vector] Add support for distributing masked writes
Because the mask applies to the un-permuted write vector, we can
simply distribute the mask identically to the vector, if present.
---
.../Vector/Transforms/VectorDistribute.cpp | 38 ++++++++++++++-----
.../Vector/vector-warp-distribute.mlir | 28 ++++++++++++++
2 files changed, 57 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 78015e3deeb967e..bbc28e64bbfd8ac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
WarpExecuteOnLane0Op warpOp,
vector::TransferWriteOp writeOp,
- VectorType targetType) {
+ VectorType targetType,
+ VectorType maybeMaskType) {
assert(writeOp->getParentOp() == warpOp &&
"write must be nested immediately under warp");
OpBuilder::InsertionGuard g(rewriter);
SmallVector<size_t> newRetIndices;
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, ValueRange{{writeOp.getVector()}},
- TypeRange{targetType}, newRetIndices);
+ WarpExecuteOnLane0Op newWarpOp;
+ if (maybeMaskType) {
+ newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
+ TypeRange{targetType, maybeMaskType}, newRetIndices);
+ } else {
+ newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+ TypeRange{targetType}, newRetIndices);
+ }
rewriter.setInsertionPointAfter(newWarpOp);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);
newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+ if (maybeMaskType)
+ newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
return newWriteOp;
}
@@ -489,10 +499,18 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!targetType)
return failure();
+ // 2.5 Compute the distributed type for the new mask.
+ VectorType maskType;
+ if (writeOp.getMask()) {
+ maskType =
+ getDistributedType(writeOp.getMask().getType().cast<VectorType>(),
+ map, warpOp.getWarpSize());
+ }
+
// 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
// the rest.
vector::TransferWriteOp newWriteOp =
- cloneWriteOp(rewriter, warpOp, writeOp, targetType);
+ cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
// 4. Reindex the write using the distribution map.
auto newWarpOp =
@@ -561,10 +579,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
- // Ops with mask not supported yet.
- if (writeOp.getMask())
- return failure();
-
auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
if (!warpOp)
return failure();
@@ -575,8 +589,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!isMemoryEffectFree(nextOp))
return failure();
+ Value maybeMask = writeOp.getMask();
if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
return writeOp.getVector() == value ||
+ (maybeMask && maybeMask == value) ||
warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();
@@ -584,6 +600,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
return success();
+ // Masked writes not supported for extraction.
+ if (writeOp.getMask())
+ return failure();
+
if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
return success();
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5ec02ce002ffbd6..f050bcd246e5ef7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1253,3 +1253,31 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
// CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
// CHECK-PROP: return %[[READ]] : vector<1xf32>
+
+// -----
+
+func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+ %mask = "mask_def_0"() : () -> (vector<4096xi1>)
+ %mask2 = "mask_def_1"() : () -> (vector<32xi1>)
+ %0 = "some_def_0"() : () -> (vector<4096xf32>)
+ %1 = "some_def_1"() : () -> (vector<32xf32>)
+ vector.transfer_write %0, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
+ vector.transfer_write %1, %dest[%c0], %mask2 : vector<32xf32>, memref<4096xf32>
+ vector.yield
+ }
+ return
+}
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_masked_write(
+// CHECK-DIST-AND-PROP: %[[W:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xi1>, vector<128xf32>, vector<128xi1>) {
+// CHECK-DIST-AND-PROP: %[[M0:.*]] = "mask_def_0"
+// CHECK-DIST-AND-PROP: %[[M1:.*]] = "mask_def_1"
+// CHECK-DIST-AND-PROP: %[[V0:.*]] = "some_def_0"
+// CHECK-DIST-AND-PROP: %[[V1:.*]] = "some_def_1"
+// CHECK-DIST-AND-PROP: vector.yield %[[V1]], %[[M1]], %[[V0]], %[[M0]]
+// CHECK-DIST-AND-PROP-SAME: vector<32xf32>, vector<32xi1>, vector<4096xf32>, vector<4096xi1>
+// CHECK-DIST-AND-PROP: }
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
More information about the Mlir-commits
mailing list