[Mlir-commits] [mlir] [MLIR] [XeGPU] Add distribution pattern for vector.constant_mask from Wg To Sg (PR #168118)
Charitha Saumya
llvmlistbot at llvm.org
Thu Nov 20 11:45:43 PST 2025
================
@@ -1285,6 +1285,71 @@ struct WgToSgVectorTransposeOp
}
};
+// This pattern distributes the vector.constant_mask ops to work at subgroup
+// level.
+struct WgToSgVectorConstantMaskOp
+ : public OpConversionPattern<vector::ConstantMaskOp> {
+ using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
+
+ // Get subgroup ID.
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+ // Each subgroup computes its local mask size as: min(max(wgMaskSize -
----------------
charithaintc wrote:
Please say "in each dimension" and reword the equation so that it indexes the dimensions explicitly.
https://github.com/llvm/llvm-project/pull/168118
More information about the Mlir-commits
mailing list