[Mlir-commits] [mlir] [mlir][vector] Add pattern to distribute masked reads (PR #71610)

Quinn Dawkins llvmlistbot at llvm.org
Tue Nov 7 16:19:24 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71610

Because the distribution is based on types, supporting general masked reads requires first materializing the permutation map in IR to align the elements of the mask with the elements read by the transfer op. For now, only masked reads with a minor identity permutation map are supported.

From c4810393c1da4fd2381d13cf044bf6fcdbb24f53 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 14:38:17 -0500
Subject: [PATCH] [MLIR][Vector] Add pattern to distribute masked reads

Because the distribution is based on types, supporting general masked
reads requires first materializing the permutation map in IR to align
the elements of the mask with the elements read by the transfer op. For
now, only masked reads with a minor identity permutation map are supported.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 41 +++++++++++++++----
 .../Vector/vector-warp-distribute.mlir        | 30 ++++++++++++++
 2 files changed, 63 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..15dbec1501a2f84 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -818,15 +818,39 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto distributedType = cast<VectorType>(distributedVal.getType());
     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
+
+    // Distribute the mask if present.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointAfter(warpOp);
+    WarpExecuteOnLane0Op newWarpOp = warpOp;
+    Value newMask = read.getMask();
+    if (read.getMask()) {
+      // TODO: Distribution of masked reads with non-trivial permutation maps
+      // requires the distribution of the mask to elementwise match the
+      // distribution of the permuted vector being read. Currently, which lane
+      // is responsible for which element is determined solely by shape
+      // information on the warp op, so supporting this requires materializing
+      // the permutation in IR.
+      if (!read.getPermutationMap().isMinorIdentity())
+        return failure();
+      VectorType maskType =
+          getDistributedType(read.getMask().getType().cast<VectorType>(), map,
+                             warpOp.getWarpSize());
+      SmallVector<size_t> newRetIndices;
+      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
+          newRetIndices);
+      newMask = newWarpOp.getResult(newRetIndices[0]);
+      distributedVal = newWarpOp.getResult(operandIndex);
+    }
+
+    rewriter.setInsertionPointAfter(newWarpOp);
 
     // Try to delinearize the lane ID to match the rank expected for
     // distribution.
     SmallVector<Value> delinearizedIds;
     if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
-                           distributedType.getShape(), warpOp.getWarpSize(),
-                           warpOp.getLaneid(), delinearizedIds))
+                           distributedType.getShape(), newWarpOp.getWarpSize(),
+                           newWarpOp.getLaneid(), delinearizedIds))
       return rewriter.notifyMatchFailure(
           read, "cannot delinearize lane ID for distribution");
     assert(!delinearizedIds.empty() || map.getNumResults() == 0);
@@ -846,7 +870,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     }
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
-        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
+        read.getPermutationMapAttr(), read.getPadding(), newMask,
         read.getInBoundsAttr());
 
     // Check that the produced operation is legal.
@@ -854,18 +878,19 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // warpOp's body, which is illegal.
     // We do the check late because incdices may be changed by
     // makeComposeAffineApply. This rewrite may remove dependencies from
-    // warOp's body.
-    // E.g., warop {
+    // warpOp's body.
+    // E.g., warpop {
     //   %idx = affine.apply...[%outsideDef]
     //   ... = transfer_read ...[%idx]
     // }
     // will be rewritten in:
-    // warop {
+    // warpop {
     // }
     //  %new_idx = affine.apply...[%outsideDef]
     //   ... = transfer_read ...[%new_idx]
     if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
-          return warpOp.isDefinedOutsideOfRegion(value);
+          return (newRead.getMask() && value == newRead.getMask()) ||
+                 newWarpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..2a1007fbbe86435 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1281,3 +1281,33 @@ func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>)
 //       CHECK-DIST-AND-PROP:   }
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
+
+// -----
+
+func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> (vector<2xf32>, vector<2x2xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>, vector<2x2xf32>) {
+    %mask = "mask_def_0"() : () -> (vector<128xi1>)
+    %0 = vector.transfer_read %src[%c0, %index], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
+    %mask2 = "mask_def_1"() : () -> (vector<128x2xi1>)
+    %1 = vector.transfer_read %src[%c0, %index], %f0, %mask2 {in_bounds = [true, true]} : memref<4096x4096xf32>, vector<128x2xf32>
+    vector.yield %0, %1 : vector<128xf32>, vector<128x2xf32>
+  }
+  return %r#0, %r#1 : vector<2xf32>, vector<2x2xf32>
+}
+
+//   CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+//   CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_masked_transfer_read
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>, vector<2x2xi1>) {
+//       CHECK-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-PROP:     %[[M1:.*]] = "mask_def_1"
+//       CHECK-PROP:     vector.yield %[[M0]], %[[M1]] : vector<128xi1>, vector<128x2xi1>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[ARG2]]], {{.*}}, %[[R]]#1 {{.*}} vector<2x2xf32>
+//       CHECK-PROP:   %[[DIST_READ_IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[C0]], %[[DIST_READ_IDX1]]], {{.*}}, %[[R]]#0 {{.*}} vector<2xf32>



More information about the Mlir-commits mailing list