[Mlir-commits] [mlir] [mlir][vector] Add distribution pattern for vector.create_mask (PR #71619)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 7 17:57:08 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
This is the last step needed for basic support for distributing masked vector code. The lane id gets delinearized based on the distributed mask shape and then compared against the original mask sizes to compute the bounds for the distributed mask.
---
Full diff: https://github.com/llvm/llvm-project/pull/71619.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+83-1)
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+60)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..9d5354300efe31a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1047,6 +1047,88 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Sink out vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
+/// ...
+/// %mask = vector.create_mask %0 : vector<32xi1>
+/// vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0) {
+/// ...
+/// }
+/// %cmp = arith.cmpi ult, %laneid, %0
+/// %ub = arith.select %cmp, %c1, %c0
+/// %1 = vector.create_mask %ub : vector<1xi1>
+/// ```
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+ if (!yieldOperand)
+ return failure();
+
+ auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+ Location loc = mask.getLoc();
+ unsigned operandIndex = yieldOperand->getOperandNumber();
+
+ auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+ VectorType seqType = mask.getVectorType();
+ auto seqShape = seqType.getShape();
+ auto distShape = distType.getShape();
+
+ rewriter.setInsertionPointAfter(warpOp);
+
+ // Delinearize the lane ID for constructing the distributed mask sizes.
+ SmallVector<Value> delinearizedIds;
+ if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+ warpOp.getWarpSize(), warpOp.getLaneid(),
+ delinearizedIds))
+ return rewriter.notifyMatchFailure(
+ mask, "cannot delinearize lane ID for distribution");
+ assert(!delinearizedIds.empty());
+
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value prevIsZero;
+ SmallVector<Value> newOperands;
+ for (int i = 0, e = distShape.size(); i < e; ++i) {
+ // Get `mask_size[i] - lane_id[i] * dist_sizes[i]` to find the distance
+ // from the first (smallest) mask index owned by this lane to the
+ // original mask size.
+ Value maskDimIdx = affine::makeComposedAffineApply(
+ rewriter, loc, s1 - s0 * distShape[i],
+ {delinearizedIds[i], mask.getOperand(i)});
+ // Clamp to the range [0, dist_sizes[i]].
+ Value clampZero = rewriter.create<arith::MaxSIOp>(loc, zero, maskDimIdx);
+ Value vecSizeVal = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(distShape[i]));
+ Value clampVecSize =
+ rewriter.create<arith::MinSIOp>(loc, vecSizeVal, clampZero);
+ // If a previous mask size is zero, all trailing sizes must also be zero.
+ if (prevIsZero) {
+ clampVecSize = rewriter.create<arith::SelectOp>(loc, prevIsZero, zero,
+ clampVecSize);
+ }
+ prevIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ clampVecSize, zero);
+ newOperands.push_back(clampVecSize);
+ }
+
+ auto newMask =
+ rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
+ rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+ return success();
+ }
+};
+
/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1701,7 +1783,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
- WarpOpInsert>(patterns.getContext(), benefit);
+ WarpOpCreateMask, WarpOpInsert>(patterns.getContext(), benefit);
patterns.add<WarpOpExtractElement>(patterns.getContext(),
warpShuffleFromIdxFn, benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..b8c6b7fa99418d0 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1281,3 +1281,63 @@ func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>)
// CHECK-DIST-AND-PROP: }
// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
+
+// -----
+
+func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1> {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+ %1 = vector.create_mask %m0 : vector<32xi1>
+ vector.yield %1 : vector<32xi1>
+ }
+ return %r : vector<1xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
+// CHECK-PROP-LABEL: func @warp_propagate_create_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
+// CHECK-PROP: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-PROP: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-PROP: %[[MBOUNDDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
+// CHECK-PROP: %[[CLAMPZERO:.+]] = arith.maxsi %[[MBOUNDDIST]], %[[C0]] : index
+// CHECK-PROP: %[[MDIST:.+]] = arith.minsi %[[CLAMPZERO]], %[[C1]] : index
+// CHECK-PROP: vector.create_mask %[[MDIST]] : vector<1xi1>
+
+// -----
+
+func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1: index, %m2: index) -> vector<1x2x4xi1> {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+ %1 = vector.create_mask %m0, %m1, %m2 : vector<16x4x4xi1>
+ vector.yield %1 : vector<16x4x4xi1>
+ }
+ return %r : vector<1x2x4xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0, s1] -> (s0 - s1 floordiv 2)>
+// CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
+// CHECK-PROP-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-PROP-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-PROP-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-PROP-DAG: %[[C4:.+]] = arith.constant 4 : index
+
+// Compute distributed m0 based on the first (outermost) delinearized lane id.
+// CHECK-PROP: affine.apply #map()[%[[M0]], %[[LANEID]]]
+// CHECK-PROP: arith.maxsi {{.*}}, %[[C0]] : index
+// CHECK-PROP: %[[DISTM0:.+]] = arith.minsi {{.*}}, %[[C1]] : index
+// CHECK-PROP: arith.cmpi eq, {{.*}}, %[[C0]] : index
+
+// Compute m1 based on the second delinearized lane id. If m0 is zero, m1 must
+// also be zero for the mask to be valid.
+// CHECK-PROP: affine.apply #map1()[%[[M1]], %[[LANEID]]]
+// CHECK-PROP: arith.maxsi {{.*}}, %[[C0]] : index
+// CHECK-PROP: arith.minsi {{.*}}, %[[C2]] : index
+// CHECK-PROP: %[[DISTM1:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+// Compute m2 and propagate zeros.
+// CHECK-PROP: arith.cmpi eq, {{.*}}, %[[C0]] : index
+// CHECK-PROP: arith.maxsi %[[M2]], %[[C0]] : index
+// CHECK-PROP: arith.minsi {{.*}}, %[[C4]] : index
+// CHECK-PROP: %[[DISTM2:.+]] = arith.select {{.*}}, %[[C0]], {{.*}} : index
+
+// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[DISTM2]] : vector<1x2x4xi1>
``````````
</details>
https://github.com/llvm/llvm-project/pull/71619
More information about the Mlir-commits
mailing list