[Mlir-commits] [mlir] [mlir][vector] Add distribution pattern for vector.create_mask (PR #71619)
Quinn Dawkins
llvmlistbot at llvm.org
Wed Nov 8 07:52:32 PST 2023
================
@@ -1047,6 +1047,82 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Sink out vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
+/// ...
+/// %mask = vector.create_mask %0 : vector<32xi1>
+/// vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0) {
+/// ...
+/// }
+/// %cmp = arith.cmpi ult, %laneid, %0
+/// %ub = arith.select %cmp, %c1, %c0
+/// %1 = vector.create_mask %ub : vector<1xi1>
+/// ```
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+ if (!yieldOperand)
+ return failure();
+
+ auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+
+ // Early exit if any values needed for calculating the new mask indices
+ // are defined inside the warp op.
+ if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+ return warpOp.isDefinedOutsideOfRegion(value);
+ }))
+ return failure();
+
+ Location loc = mask.getLoc();
+ unsigned operandIndex = yieldOperand->getOperandNumber();
+
+ auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+ VectorType seqType = mask.getVectorType();
+ auto seqShape = seqType.getShape();
+ auto distShape = distType.getShape();
+
+ rewriter.setInsertionPointAfter(warpOp);
+
+ // Delinearize the lane ID for constructing the distributed mask sizes.
+ SmallVector<Value> delinearizedIds;
+ if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+ warpOp.getWarpSize(), warpOp.getLaneid(),
+ delinearizedIds))
+ return rewriter.notifyMatchFailure(
+ mask, "cannot delinearize lane ID for distribution");
+ assert(!delinearizedIds.empty());
+
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ SmallVector<Value> newOperands;
+ for (int i = 0, e = distShape.size(); i < e; ++i) {
+ // Get `mask_size[i] - lane_id[i] * (seq_sizes[i]/dist_sizes[i])` to find
+ // the distance from the largest mask index owned by this lane to the
+ // original mask size. vector.create_mask implicitly clamps mask sizes to
+ // the range [0, mask_vector_size[i]].
----------------
qedawkins wrote:
I reread the semantics of create_mask: its operand values are implicitly clamped to the valid range. This contrasts with constant_mask, which requires the constant values to lie within [0, mask_size] and also requires that a zero size in any dimension propagate through the whole mask. This lets me drop the explicit clamping I was doing before.
https://github.com/llvm/llvm-project/pull/71619
More information about the Mlir-commits
mailing list