[Mlir-commits] [mlir] [mlir][vector] Add pattern to distribute masked reads (PR #71610)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 7 16:19:53 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

Because the distribution is based on types, supporting general masked reads requires first materializing the permutation map in IR to align the elements of the mask with the elements read by the transfer op. For now, only masked reads whose permutation map is trivial (a minor identity) are supported; other masked reads are left undistributed.

---
Full diff: https://github.com/llvm/llvm-project/pull/71610.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+33-8) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+30) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..15dbec1501a2f84 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -818,15 +818,39 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto distributedType = cast<VectorType>(distributedVal.getType());
     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
+
+    // Distribute the mask if present.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointAfter(warpOp);
+    WarpExecuteOnLane0Op newWarpOp = warpOp;
+    Value newMask = read.getMask();
+    if (read.getMask()) {
+      // TODO: Distribution of masked reads with non-trivial permutation maps
+      // requires the distribution of the mask to elementwise match the
+      // distribution of the permuted written vector. Currently the details
+      // of which lane is responsible for which element is captured strictly
+      // by shape information on the warp op, and thus requires materializing
+      // the permutation in IR.
+      if (!read.getPermutationMap().isMinorIdentity())
+        return failure();
+      VectorType maskType =
+          getDistributedType(read.getMask().getType().cast<VectorType>(), map,
+                             warpOp.getWarpSize());
+      SmallVector<size_t> newRetIndices;
+      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
+          newRetIndices);
+      newMask = newWarpOp.getResult(newRetIndices[0]);
+      distributedVal = newWarpOp.getResult(operandIndex);
+    }
+
+    rewriter.setInsertionPointAfter(newWarpOp);
 
     // Try to delinearize the lane ID to match the rank expected for
     // distribution.
     SmallVector<Value> delinearizedIds;
     if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
-                           distributedType.getShape(), warpOp.getWarpSize(),
-                           warpOp.getLaneid(), delinearizedIds))
+                           distributedType.getShape(), newWarpOp.getWarpSize(),
+                           newWarpOp.getLaneid(), delinearizedIds))
       return rewriter.notifyMatchFailure(
           read, "cannot delinearize lane ID for distribution");
     assert(!delinearizedIds.empty() || map.getNumResults() == 0);
@@ -846,7 +870,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     }
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
-        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
+        read.getPermutationMapAttr(), read.getPadding(), newMask,
         read.getInBoundsAttr());
 
     // Check that the produced operation is legal.
@@ -854,18 +878,19 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // warpOp's body, which is illegal.
     // We do the check late because incdices may be changed by
     // makeComposeAffineApply. This rewrite may remove dependencies from
-    // warOp's body.
-    // E.g., warop {
+    // warpOp's body.
+    // E.g., warpop {
     //   %idx = affine.apply...[%outsideDef]
     //   ... = transfer_read ...[%idx]
     // }
     // will be rewritten in:
-    // warop {
+    // warpop {
     // }
     //  %new_idx = affine.apply...[%outsideDef]
     //   ... = transfer_read ...[%new_idx]
     if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
-          return warpOp.isDefinedOutsideOfRegion(value);
+          return (newRead.getMask() && value == newRead.getMask()) ||
+                 newWarpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..2a1007fbbe86435 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1281,3 +1281,33 @@ func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>)
 //       CHECK-DIST-AND-PROP:   }
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
+
+// -----
+
+func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> (vector<2xf32>, vector<2x2xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>, vector<2x2xf32>) {
+    %mask = "mask_def_0"() : () -> (vector<128xi1>)
+    %0 = vector.transfer_read %src[%c0, %index], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
+    %mask2 = "mask_def_1"() : () -> (vector<128x2xi1>)
+    %1 = vector.transfer_read %src[%c0, %index], %f0, %mask2 {in_bounds = [true, true]} : memref<4096x4096xf32>, vector<128x2xf32>
+    vector.yield %0, %1 : vector<128xf32>, vector<128x2xf32>
+  }
+  return %r#0, %r#1 : vector<2xf32>, vector<2x2xf32>
+}
+
+//   CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+//   CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_masked_transfer_read
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>, vector<2x2xi1>) {
+//       CHECK-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-PROP:     %[[M1:.*]] = "mask_def_1"
+//       CHECK-PROP:     vector.yield %[[M0]], %[[M1]] : vector<128xi1>, vector<128x2xi1>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[ARG2]]], {{.*}}, %[[R]]#1 {{.*}} vector<2x2xf32>
+//       CHECK-PROP:   %[[DIST_READ_IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[C0]], %[[DIST_READ_IDX1]]], {{.*}}, %[[R]]#0 {{.*}} vector<2xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/71610


More information about the Mlir-commits mailing list