[Mlir-commits] [mlir] [MLIR][Vector] Add support for distributing masked writes (PR #71482)

Quinn Dawkins llvmlistbot at llvm.org
Mon Nov 6 19:54:15 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71482

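Because the mask applies to the un-permuted write vector, it has the same shape and layout as the written vector whenever the write's permutation map is a minor identity; in that case we can simply distribute the mask identically to the vector, if present.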

>From 0a720ee8347c743ea95c4bc9e5737f7f0e71efc9 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 10:22:28 -0500
Subject: [PATCH] [MLIR][Vector] Add support for distributing masked writes

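Because the mask applies to the un-permuted write vector, it has the
same shape and layout as the written vector whenever the write's
permutation map is a minor identity; in that case we can simply
distribute the mask identically to the vector, if present.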
---
 .../Vector/Transforms/VectorDistribute.cpp    | 38 ++++++++++++++-----
 .../Vector/vector-warp-distribute.mlir        | 28 ++++++++++++++
 2 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 78015e3deeb967e..bbc28e64bbfd8ac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                             WarpExecuteOnLane0Op warpOp,
                                             vector::TransferWriteOp writeOp,
-                                            VectorType targetType) {
+                                            VectorType targetType,
+                                            VectorType maybeMaskType) {
   assert(writeOp->getParentOp() == warpOp &&
          "write must be nested immediately under warp");
   OpBuilder::InsertionGuard g(rewriter);
   SmallVector<size_t> newRetIndices;
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
-      TypeRange{targetType}, newRetIndices);
+  // Yield the written vector and, for masked writes, its mask so both are
+  // returned from the new warp op with the requested distributed types.
+  SmallVector<Value> yieldedValues = {writeOp.getVector()};
+  SmallVector<Type> yieldedTypes = {targetType};
+  if (maybeMaskType) {
+    yieldedValues.push_back(writeOp.getMask());
+    yieldedTypes.push_back(maybeMaskType);
+  }
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, yieldedValues, yieldedTypes, newRetIndices);
   rewriter.setInsertionPointAfter(newWarpOp);
   auto newWriteOp =
       cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
   rewriter.eraseOp(writeOp);
   newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+  if (maybeMaskType)
+    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
   return newWriteOp;
 }
 
@@ -489,10 +499,23 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (!targetType)
       return failure();
 
+    // 2.5 Compute the distributed type for the new mask.
+    VectorType maskType;
+    if (writeOp.getMask()) {
+      // The mask is indexed like the un-permuted (memory-order) access, so it
+      // only has the written vector's layout, and can reuse `map`, when the
+      // permutation map is a minor identity.
+      if (!writeOp.getPermutationMap().isMinorIdentity())
+        return failure();
+      maskType = getDistributedType(
+          llvm::cast<VectorType>(writeOp.getMask().getType()), map,
+          warpOp.getWarpSize());
+    }
+
     // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
     // the rest.
     vector::TransferWriteOp newWriteOp =
-        cloneWriteOp(rewriter, warpOp, writeOp, targetType);
+        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
 
     // 4. Reindex the write using the distribution map.
     auto newWarpOp =
@@ -561,10 +579,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
-    // Ops with mask not supported yet.
-    if (writeOp.getMask())
-      return failure();
-
     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
     if (!warpOp)
       return failure();
@@ -575,8 +589,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
       if (!isMemoryEffectFree(nextOp))
         return failure();
 
+    Value maybeMask = writeOp.getMask();
     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
           return writeOp.getVector() == value ||
+                 (maybeMask && maybeMask == value) ||
                  warpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
@@ -584,6 +600,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
       return success();
 
+    // Masked writes are not yet supported by the tryExtractOp path.
+    if (writeOp.getMask())
+      return failure();
+
     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
       return success();
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5ec02ce002ffbd6..f050bcd246e5ef7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1253,3 +1253,33 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
 //  CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
 //       CHECK-PROP:   %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
 //       CHECK-PROP:   return %[[READ]] : vector<1xf32>
+
+// -----
+
+// Masked writes: each mask is yielded next to its vector and distributed to
+// the same per-lane shape (vector<128xi1> / vector<1xi1>).
+func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+    %mask = "mask_def_0"() : () -> (vector<4096xi1>)
+    %mask2 = "mask_def_1"() : () -> (vector<32xi1>)
+    %0 = "some_def_0"() : () -> (vector<4096xf32>)
+    %1 = "some_def_1"() : () -> (vector<32xf32>)
+    vector.transfer_write %0, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
+    vector.transfer_write %1, %dest[%c0], %mask2 : vector<32xf32>, memref<4096xf32>
+    vector.yield
+  }
+  return
+}
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_masked_write(
+//       CHECK-DIST-AND-PROP:   %[[W:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xi1>, vector<128xf32>, vector<128xi1>) {
+//       CHECK-DIST-AND-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-DIST-AND-PROP:     %[[M1:.*]] = "mask_def_1"
+//       CHECK-DIST-AND-PROP:     %[[V0:.*]] = "some_def_0"
+//       CHECK-DIST-AND-PROP:     %[[V1:.*]] = "some_def_1"
+//       CHECK-DIST-AND-PROP:     vector.yield %[[V1]], %[[M1]], %[[V0]], %[[M0]]
+//  CHECK-DIST-AND-PROP-SAME:       vector<32xf32>, vector<32xi1>, vector<4096xf32>, vector<4096xi1>
+//       CHECK-DIST-AND-PROP:   }
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>



More information about the Mlir-commits mailing list