[Mlir-commits] [mlir] [MLIR] [XeGPU] Add distribution pattern for vector.constant_mask from Wg To Sg (PR #168118)

Charitha Saumya llvmlistbot at llvm.org
Thu Nov 20 11:45:43 PST 2025


================
@@ -1285,6 +1285,71 @@ struct WgToSgVectorTransposeOp
   }
 };
 
+// This pattern distributes the vector.constant_mask ops to work at subgroup
+// level.
+struct WgToSgVectorConstantMaskOp
+    : public OpConversionPattern<vector::ConstantMaskOp> {
+  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+
+    ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
+
+    // Get subgroup ID.
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto sgOffsets =
+        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+    if (failed(sgOffsets))
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+    // Each subgroup computes its local mask size as: min(max(wgMaskSize -
----------------
charithaintc wrote:

Please state that this is computed "in each dimension", and reword the equation so that it explicitly indexes each dimension.

https://github.com/llvm/llvm-project/pull/168118


More information about the Mlir-commits mailing list