[Mlir-commits] [mlir] [mlir][vector] Add distribution pattern for vector.create_mask (PR #71619)

Lei Zhang llvmlistbot at llvm.org
Thu Nov 9 21:16:59 PST 2023


================
@@ -1071,6 +1071,82 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Sink out vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
+///   ...
+///   %mask = vector.create_mask %0 : vector<32xi1>
+///   vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %cmp = arith.cmpi ult, %laneid, %0
+/// %ub = arith.select %cmp, %c0, %c1
+/// %1 = vector.create_mask %ub : vector<1xi1>
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+    if (!yieldOperand)
+      return failure();
+
+    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+
+    // Early exit if any values needed for calculating the new mask indices
+    // are defined inside the warp op.
+    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+          return warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
+    Location loc = mask.getLoc();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+
+    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+    VectorType seqType = mask.getVectorType();
+    auto seqShape = seqType.getShape();
+    auto distShape = distType.getShape();
+
+    rewriter.setInsertionPointAfter(warpOp);
+
+    // Delinearize the lane ID for constructing the distributed mask sizes.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+                           warpOp.getWarpSize(), warpOp.getLaneid(),
+                           delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          mask, "cannot delinearize lane ID for distribution");
+    assert(!delinearizedIds.empty());
+
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    SmallVector<Value> newOperands;
+    for (int i = 0, e = distShape.size(); i < e; ++i) {
+      // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
----------------
antiagainst wrote:

The comment seems wrong to me. Shouldn't this be `mask_dim_range_upper_limit[i] - lane[i] * distributed_shape[i]`?

https://github.com/llvm/llvm-project/pull/71619


More information about the Mlir-commits mailing list