[Mlir-commits] [mlir] [mlir][vector] Add distribution pattern for vector.create_mask (PR #71619)

Quinn Dawkins llvmlistbot at llvm.org
Tue Nov 7 17:56:40 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71619

This is the last step needed for basic support for distributing masked vector code. The lane ID is delinearized according to the distributed mask shape. Each lane's offset along every dimension is then subtracted from the original mask sizes, and the result is clamped to the distributed size, to compute the bounds of the distributed mask.

>From 3f7fea5f98e4dbd3476d64b95a7c475c7ce80165 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 21:22:07 -0500
Subject: [PATCH] [MLIR][Vector] Add distribution pattern for
 vector.create_mask

This is the last step needed for basic support for distributing masked
vector code. The lane ID is delinearized according to the distributed
mask shape. Each lane's offset along every dimension is then subtracted
from the original mask sizes, and the result is clamped to the
distributed size, to compute the bounds of the distributed mask.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 84 ++++++++++++++++++-
 .../Vector/vector-warp-distribute.mlir        | 60 +++++++++++++
 2 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..9d5354300efe31a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1047,6 +1047,109 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Sink out a vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xi1>) {
+///   ...
+///   %mask = vector.create_mask %0 : vector<32xi1>
+///   vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0)[32] {
+///   ...
+/// }
+/// %sub = affine.apply affine_map<()[s0, s1] -> (s1 - s0)>()[%arg0, %0]
+/// %lb = arith.maxsi %c0, %sub : index
+/// %ub = arith.minsi %c1, %lb : index
+/// %1 = vector.create_mask %ub : vector<1xi1>
+/// ```
+/// Along each dimension `i`, a lane owns `dist_sizes[i]` consecutive elements
+/// starting at `lane_id[i] * dist_sizes[i]`, so its mask size is the original
+/// mask size minus that offset, clamped to `[0, dist_sizes[i]]`. Only masks
+/// whose sizes are all defined outside of the warp op are distributed, because
+/// the new mask is created after the warp op.
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    // The distributed mask is materialized after the warp op, so sizes
+    // computed inside the region would not dominate their new uses. Only
+    // consider masks whose sizes are all defined outside of the warp op.
+    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
+      auto createMask = dyn_cast<vector::CreateMaskOp>(op);
+      return createMask &&
+             llvm::all_of(createMask->getOperands(), [&](Value size) {
+               return warpOp.isDefinedOutsideOfRegion(size);
+             });
+    });
+    if (!yieldOperand)
+      return failure();
+
+    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    Location loc = mask.getLoc();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+
+    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+    VectorType seqType = mask.getVectorType();
+    ArrayRef<int64_t> seqShape = seqType.getShape();
+    ArrayRef<int64_t> distShape = distType.getShape();
+
+    rewriter.setInsertionPointAfter(warpOp);
+
+    // Delinearize the lane ID for constructing the distributed mask sizes.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+                           warpOp.getWarpSize(), warpOp.getLaneid(),
+                           delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          mask, "cannot delinearize lane ID for distribution");
+
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    // No lane IDs are produced when the mask is not actually distributed
+    // (identical sequential and distributed shapes). Every lane then owns the
+    // whole mask, i.e. its offset along each dimension is zero.
+    if (delinearizedIds.empty())
+      delinearizedIds.assign(distShape.size(), zero);
+
+    Value prevIsZero;
+    SmallVector<Value> newOperands;
+    for (int i = 0, e = distShape.size(); i < e; ++i) {
+      // This lane owns elements [lane_id[i] * dist_sizes[i],
+      // (lane_id[i] + 1) * dist_sizes[i]) along dimension i, so its local
+      // mask size is `mask_size[i] - lane_id[i] * dist_sizes[i]`.
+      Value maskDimIdx = affine::makeComposedAffineApply(
+          rewriter, loc, s1 - s0 * distShape[i],
+          {delinearizedIds[i], mask.getOperand(i)});
+      // Clamp to the range [0, dist_sizes[i]].
+      Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
+      Value vecSizeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, distShape[i]);
+      Value clampVecSize =
+          rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
+      // Once a leading size is zero the whole mask is all-false, so the
+      // trailing sizes are forced to zero as well.
+      if (prevIsZero) {
+        clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
+                                                        clampVecSize);
+      }
+      prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                  clampVecSize, zero);
+      newOperands.push_back(clampVecSize);
+    }
+
+    auto newMask =
+        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1701,7 +1783,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
                WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
-               WarpOpInsert>(patterns.getContext(), benefit);
+               WarpOpCreateMask, WarpOpInsert>(patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..b8c6b7fa99418d0 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1281,3 +1281,63 @@ func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>)
 //       CHECK-DIST-AND-PROP:   }
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
+
+// -----
+
+func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.create_mask %m0 : vector<32xi1>
+    vector.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
+// CHECK-PROP-LABEL: func @warp_propagate_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
+//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK-PROP:   %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
+//       CHECK-PROP:   %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
+//       CHECK-PROP:   %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+//       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
+
+// -----
+
+func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1: index, %m2: index) -> vector<1x2x4xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.create_mask %m0, %m1, %m2 : vector<16x4x4xi1>
+    vector.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0, s1] -> (s0 - s1 floordiv 2)>
+//   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
+//   CHECK-PROP-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-PROP-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-PROP-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-PROP-DAG:   %[[C4:.+]] = arith.constant 4 : index
+
+// Compute distributed m0 based on the first (outermost) delinearized lane id.
+//       CHECK-PROP:   affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+
+// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 is
+// forced to zero as well.
+//       CHECK-PROP:   affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
+//       CHECK-PROP:   arith.maxsi {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C2]] : index
+//       CHECK-PROP:   %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+// Compute m2 and propagate zeros.
+//       CHECK-PROP:   arith.cmpi eq, {{.*}}, %[[C0]] : index
+//       CHECK-PROP:   arith.maxsi %[[M2]], %[[C0]] : index
+//       CHECK-PROP:   arith.minsi {{.*}}, %[[C4]] : index
+//       CHECK-PROP:   %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>



More information about the Mlir-commits mailing list