[Mlir-commits] [mlir] d4d2891 - [mlir][vector] Add distribution pattern for vector.create_mask (#71619)
llvmlistbot at llvm.org
Fri Nov 10 07:09:41 PST 2023
Author: Quinn Dawkins
Date: 2023-11-10T10:09:37-05:00
New Revision: d4d289144764f4d293874d8959cc0d65ff79148f
URL: https://github.com/llvm/llvm-project/commit/d4d289144764f4d293874d8959cc0d65ff79148f
DIFF: https://github.com/llvm/llvm-project/commit/d4d289144764f4d293874d8959cc0d65ff79148f.diff
LOG: [mlir][vector] Add distribution pattern for vector.create_mask (#71619)
This is the last step needed for basic support for distributing masked
vector code. The lane ID is delinearized based on the distributed mask
shape and then compared against the original mask sizes to compute the
bounds for the distributed mask. Note that the distribution of masks is
implicit in the shape specified by the warp op; as a result, it is the
responsibility of the consumer of the mask to ensure the distributed
mask matches its own distribution semantics.
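As a minimal sketch of the 1-D case (assuming a warp size of 32 and a
distributed result type of vector<1xi1>, mirroring the new test below), the
rewrite hoists the mask out of the warp region and shrinks its size operand
by the lane's starting offset:

    // Inside the warp region, before distribution:
    %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
      %mask = vector.create_mask %m0 : vector<32xi1>
      vector.yield %mask : vector<32xi1>
    }

    // After distribution, lane i owns element i of the 32-wide mask, so its
    // local mask size is %m0 - %laneid (vector.create_mask clamps this to
    // the valid range for vector<1xi1>):
    %ub = affine.apply affine_map<()[s0, s1] -> (-s0 + s1)>()[%laneid, %m0]
    %r = vector.create_mask %ub : vector<1xi1>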
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 334d23e08419cea..645caa9c1378821 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1078,6 +1078,83 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Sink out vector.create_mask op feeding into a warp op yield.
+/// ```
+/// %0 = ...
+/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+///   ...
+///   %mask = vector.create_mask %0 : vector<32xi1>
+///   vector.yield %mask : vector<32xi1>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = ...
+/// vector.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %cmp = arith.cmpi ult, %laneid, %0
+/// %ub = arith.select %cmp, %c0, %c1
+/// %1 = vector.create_mask %ub : vector<1xi1>
+struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+    if (!yieldOperand)
+      return failure();
+
+    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+
+    // Early exit if any values needed for calculating the new mask indices
+    // are defined inside the warp op.
+    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+          return warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
+    Location loc = mask.getLoc();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+
+    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
+    VectorType seqType = mask.getVectorType();
+    ArrayRef<int64_t> seqShape = seqType.getShape();
+    ArrayRef<int64_t> distShape = distType.getShape();
+
+    rewriter.setInsertionPointAfter(warpOp);
+
+    // Delinearize the lane ID for constructing the distributed mask sizes.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
+                           warpOp.getWarpSize(), warpOp.getLaneid(),
+                           delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          mask, "cannot delinearize lane ID for distribution");
+    assert(!delinearizedIds.empty());
+
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    SmallVector<Value> newOperands;
+    for (int i = 0, e = distShape.size(); i < e; ++i) {
+      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
+      // find the distance from the largest mask index owned by this lane to the
+      // original mask size. `vector.create_mask` implicitly clamps mask
+      // operands to the range [0, mask_vector_size[i]], or in other words, the
+      // mask sizes are always in the range [0, mask_vector_size[i]).
+      Value maskDimIdx = affine::makeComposedAffineApply(
+          rewriter, loc, s1 - s0 * distShape[i],
+          {delinearizedIds[i], mask.getOperand(i)});
+      newOperands.push_back(maskDimIdx);
+    }
+
+    auto newMask =
+        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+    return success();
+  }
+};
+
/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1731,10 +1808,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+           WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 6d8ad5a0e88c2bd..1821190c44e3af4 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1369,3 +1369,37 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
// CHECK-DIST-AND-PROP: %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK-DIST-AND-PROP: vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
// CHECK-DIST-AND-PROP: vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
+
+// -----
+
+func.func @warp_propagate_create_mask(%laneid: index, %m0: index) -> vector<1xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.create_mask %m0 : vector<32xi1>
+    vector.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0, s1] -> (-s0 + s1)>
+// CHECK-PROP-LABEL: func @warp_propagate_create_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index
+// CHECK-PROP: %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]], %[[M0]]]
+// CHECK-PROP: vector.create_mask %[[MDIST]] : vector<1xi1>
+
+// -----
+
+func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1: index, %m2: index) -> vector<1x2x4xi1> {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.create_mask %m0, %m1, %m2 : vector<16x4x4xi1>
+    vector.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0, s1] -> (s0 - s1 floordiv 2)>
+// CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 2 + (s1 floordiv 2) * 4)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_create_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index, %[[M0:.+]]: index, %[[M1:.+]]: index, %[[M2:.+]]: index
+// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
+// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
+// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
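To spell out the delinearization in the multi-dimensional test above (an
illustrative example; the lane ID value 5 is chosen arbitrarily):
distributing vector<16x4x4xi1> to vector<1x2x4xi1> gives a 16x2x1 grid of
per-lane tiles, so lane 5 delinearizes to (5 floordiv 2, 5 mod 2, 0) =
(2, 1, 0), and each distributed mask size is the original size minus that
lane's starting offset along the dimension:

    dim 0: %m0 - 2 * 1 = %m0 - 2   (the first affine map above with s1 = 5)
    dim 1: %m1 - 1 * 2 = %m1 - 2   (the second affine map above with s1 = 5)
    dim 2: %m2 - 0 * 4 = %m2       (this dimension is not distributed)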