[Mlir-commits] [mlir] aba8ebb - [MLIR][Vector] Add distribution pattern for `vector::ConstantMaskOp` (#172268)
llvmlistbot at llvm.org
Tue Dec 16 08:24:17 PST 2025
Author: Artem Kroviakov
Date: 2025-12-16T17:24:13+01:00
New Revision: aba8ebbda0912ef2037668aaa48cfbe59991576f
URL: https://github.com/llvm/llvm-project/commit/aba8ebbda0912ef2037668aaa48cfbe59991576f
DIFF: https://github.com/llvm/llvm-project/commit/aba8ebbda0912ef2037668aaa48cfbe59991576f.diff
LOG: [MLIR][Vector] Add distribution pattern for `vector::ConstantMaskOp` (#172268)
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b5e950733a22..5334470e2e3a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1100,12 +1100,14 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
}
};
-/// Sink out vector.create_mask op feeding into a warp op yield.
+/// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op
+/// yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
/// ...
/// %mask = vector.create_mask %0 : vector<32xi1>
+/// // or %mask = vector.constant_mask [2] : vector<32xi1>
/// gpu.yield %mask : vector<32xi1>
/// }
/// ```
@@ -1118,31 +1120,45 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
struct WarpOpCreateMask : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *yieldOperand =
- getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
+ OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
if (!yieldOperand)
return failure();
- auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+ Operation *mask = yieldOperand->get().getDefiningOp<OpType>();
// Early exit if any values needed for calculating the new mask indices
// are defined inside the warp op.
- if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+ if (mask->getOperands().size() &&
+ !llvm::all_of(mask->getOperands(), [&](Value value) {
return warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();
- Location loc = mask.getLoc();
+ Location loc = mask->getLoc();
unsigned operandIndex = yieldOperand->getOperandNumber();
auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
- VectorType seqType = mask.getVectorType();
+ VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
ArrayRef<int64_t> seqShape = seqType.getShape();
ArrayRef<int64_t> distShape = distType.getShape();
+ SmallVector<Value> materializedOperands;
+ if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
+ materializedOperands.append(mask->getOperands().begin(),
+ mask->getOperands().end());
+ } else {
+ auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
+ auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
+ for (auto dimSize : dimSizes)
+ materializedOperands.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+ }
rewriter.setInsertionPointAfter(warpOp);
@@ -1170,7 +1186,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
// mask sizes are always in the range [0, mask_vector_size[i]).
Value maskDimIdx = affine::makeComposedAffineApply(
rewriter, loc, s1 - s0 * distShape[i],
- {delinearizedIds[i], mask.getOperand(i)});
+ {delinearizedIds[i], materializedOperands[i]});
newOperands.push_back(maskDimIdx);
}
@@ -2282,12 +2298,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns
- .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
- WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
- patterns.getContext(), benefit);
+ patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+ WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
+ WarpOpCreateMask<vector::CreateMaskOp>,
+ WarpOpCreateMask<vector::ConstantMaskOp>,
+ WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 0cf6dd151e16c..135db02d543ef 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1779,6 +1779,21 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
// CHECK-DIST-AND-PROP: %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK-DIST-AND-PROP: vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
// CHECK-DIST-AND-PROP: vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
+// -----
+// Constant mask [1] over 32 lanes should become a per-lane vector.create_mask of size 1 - laneid.
+func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+ %1 = vector.constant_mask [1] : vector<32xi1>
+ gpu.yield %1 : vector<32xi1>
+ }
+ return %r : vector<1xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0] -> (-s0 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_constant_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index
+// CHECK-PROP: %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]]]
+// CHECK-PROP: vector.create_mask %[[MDIST]] : vector<1xi1>
// -----
@@ -1813,6 +1828,24 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+// -----
+// 3-D constant mask [1, 1, 2] over 32 lanes: dims 0 and 1 get per-lane sizes, dim 2 keeps the constant 2.
+func.func @warp_propagate_multi_dim_constant_mask(%laneid: index) -> vector<1x2x4xi1> {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+ %1 = vector.constant_mask [1, 1, 2]: vector<16x4x4xi1>
+ gpu.yield %1 : vector<16x4x4xi1>
+ }
+ return %r : vector<1x2x4xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0] -> (-(s0 floordiv 2) + 1)>
+// CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0] -> (s0 * -2 + (s0 floordiv 2) * 4 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_constant_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index
+// CHECK-PROP: %[[CST2:.+]] = arith.constant 2 : index
+// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[LANEID]]]
+// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[LANEID]]]
+// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[CST2]] : vector<1x2x4xi1>
// -----
More information about the Mlir-commits mailing list