[Mlir-commits] [mlir] 25ec1fa - [mlir][vector] Add support for distributing masked writes (#71482)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 7 14:54:53 PST 2023


Author: Quinn Dawkins
Date: 2023-11-07T17:54:49-05:00
New Revision: 25ec1fa969a0d13f440222f575277f9601eaea76

URL: https://github.com/llvm/llvm-project/commit/25ec1fa969a0d13f440222f575277f9601eaea76
DIFF: https://github.com/llvm/llvm-project/commit/25ec1fa969a0d13f440222f575277f9601eaea76.diff

LOG: [mlir][vector] Add support for distributing masked writes (#71482)

Distributing masked writes in general requires materializing the permutation of the written vector in IR, so that the vector lines up element-wise with the mask. For now, only cases with trivial (minor identity) permutation maps are supported.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 78015e3deeb967e..e128cc71a5d628c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                             WarpExecuteOnLane0Op warpOp,
                                             vector::TransferWriteOp writeOp,
-                                            VectorType targetType) {
+                                            VectorType targetType,
+                                            VectorType maybeMaskType) {
   assert(writeOp->getParentOp() == warpOp &&
          "write must be nested immediately under warp");
   OpBuilder::InsertionGuard g(rewriter);
   SmallVector<size_t> newRetIndices;
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
-      TypeRange{targetType}, newRetIndices);
+  WarpExecuteOnLane0Op newWarpOp;
+  if (maybeMaskType) {
+    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
+        TypeRange{targetType, maybeMaskType}, newRetIndices);
+  } else {
+    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+        TypeRange{targetType}, newRetIndices);
+  }
   rewriter.setInsertionPointAfter(newWarpOp);
   auto newWriteOp =
       cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
   rewriter.eraseOp(writeOp);
   newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+  if (maybeMaskType)
+    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
   return newWriteOp;
 }
 
@@ -489,10 +499,25 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (!targetType)
       return failure();
 
+    // 2.5 Compute the distributed type for the new mask.
+    VectorType maskType;
+    if (writeOp.getMask()) {
+      // TODO: Distribution of masked writes with non-trivial permutation maps
+      // requires the distribution of the mask to elementwise match the
+      // distribution of the permuted written vector. Currently the details
+      // of which lane is responsible for which element are captured strictly
+      // by shape information on the warp op, and thus requires materializing
+      // the permutation in IR.
+      if (!writeOp.getPermutationMap().isMinorIdentity())
+        return failure();
+      maskType =
+          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
+    }
+
     // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
     // the rest.
     vector::TransferWriteOp newWriteOp =
-        cloneWriteOp(rewriter, warpOp, writeOp, targetType);
+        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
 
     // 4. Reindex the write using the distribution map.
     auto newWarpOp =
@@ -561,10 +586,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
-    // Ops with mask not supported yet.
-    if (writeOp.getMask())
-      return failure();
-
     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
     if (!warpOp)
       return failure();
@@ -575,8 +596,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
       if (!isMemoryEffectFree(nextOp))
         return failure();
 
+    Value maybeMask = writeOp.getMask();
     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
           return writeOp.getVector() == value ||
+                 (maybeMask && maybeMask == value) ||
                  warpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
@@ -584,6 +607,10 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
       return success();
 
+    // Masked writes not supported for extraction.
+    if (writeOp.getMask())
+      return failure();
+
     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
       return success();
 

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5ec02ce002ffbd6..f050bcd246e5ef7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1253,3 +1253,31 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
 //  CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
 //       CHECK-PROP:   %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
 //       CHECK-PROP:   return %[[READ]] : vector<1xf32>
+
+// -----
+
+func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+    %mask = "mask_def_0"() : () -> (vector<4096xi1>)
+    %mask2 = "mask_def_1"() : () -> (vector<32xi1>)
+    %0 = "some_def_0"() : () -> (vector<4096xf32>)
+    %1 = "some_def_1"() : () -> (vector<32xf32>)
+    vector.transfer_write %0, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
+    vector.transfer_write %1, %dest[%c0], %mask2 : vector<32xf32>, memref<4096xf32>
+    vector.yield
+  }
+  return
+}
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_masked_write(
+//       CHECK-DIST-AND-PROP:   %[[W:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xi1>, vector<128xf32>, vector<128xi1>) {
+//       CHECK-DIST-AND-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-DIST-AND-PROP:     %[[M1:.*]] = "mask_def_1"
+//       CHECK-DIST-AND-PROP:     %[[V0:.*]] = "some_def_0"
+//       CHECK-DIST-AND-PROP:     %[[V1:.*]] = "some_def_1"
+//       CHECK-DIST-AND-PROP:     vector.yield %[[V1]], %[[M1]], %[[V0]], %[[M0]]
+//  CHECK-DIST-AND-PROP-SAME:       vector<32xf32>, vector<32xi1>, vector<4096xf32>, vector<4096xi1>
+//       CHECK-DIST-AND-PROP:   }
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>


        


More information about the Mlir-commits mailing list