[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution for vector.create_mask from Wg to Sg (PR #169571)
Nishant Patel
llvmlistbot at llvm.org
Tue Nov 25 13:13:44 PST 2025
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/169571
(No description was provided for this pull request.)
>From 3df7560f79d4d21990e11d039d13246fe74d50bd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 25 Nov 2025 20:51:47 +0000
Subject: [PATCH 1/2] Add distribution for vector.create_mask op
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 72 +++++++++++++++++--
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 8 +++
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 37 ++++++++++
3 files changed, 113 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index beb9b60aa9d7a..b6de8b09ca8d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1335,6 +1335,68 @@ struct WgToSgVectorConstantMaskOp
}
};
+// This pattern distributes the vector.create_mask ops to work at subgroup
+// level.
+struct WgToSgVectorCreateMaskOp
+ : public OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern<vector::CreateMaskOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ auto wgMaskOperands = op.getOperands();
+
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+ // In each dimension, each subgroup computes its local mask size as:
+ // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
+ SmallVector<Value> newCreateMaskOps;
+ for (auto offsetSet : *sgOffsets) {
+ SmallVector<Value> maskOperands;
+
+ for (auto [i, wgMaskOperand] : llvm::enumerate(wgMaskOperands)) {
+ Value dimSizeVal =
+ arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+ Value offset = offsetSet[i];
+ Value adjustedMaskSize =
+ arith::SubIOp::create(rewriter, loc, wgMaskOperand, offset);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value nonNegative =
+ arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+ Value sgMaskSize =
+ arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
+ maskOperands.push_back(sgMaskSize);
+ }
+
+ auto newCreateMaskOp =
+ vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+ xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -1350,7 +1412,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
- WgToSgVectorConstantMaskOp>(patterns.getContext());
+ WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1477,9 +1540,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<
- vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
- vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+ vector::TransposeOp, vector::BroadcastOp,
+ vector::MultiDimReductionOp,
+ vector::ConstantMaskOp, vector::CreateMaskOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 1cddccb5fbbd1..4fb50b3b28534 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -138,5 +138,13 @@ gpu.module @test_distribution {
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
gpu.return
}
+
+ gpu.func @vector_create_mask_2D() {
+ // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+ // CHECK-NOT: vector.create_mask
+ %cst16 = arith.constant 16 : index
+ %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 574b365443a0a..48e93320093fd 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -583,6 +583,43 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: vector_create_mask_1D
+ gpu.func @vector_create_mask_1D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+ // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+ // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+ // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+ %cst8 = arith.constant 8 : index
+ %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_create_mask_2D
+ gpu.func @vector_create_mask_2D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+ // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+ // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+ // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+ // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+ // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+ // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+ // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+ %cst16 = arith.constant 16 : index
+ %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+ gpu.return
+ }
+
// CHECK-LABEL: distribute_load_slice_attr
gpu.func @distribute_load_slice_attr() {
%2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
>From f9c930d4bf7ffcbcdb5ee81ca375cf6e5267dd99 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 25 Nov 2025 21:10:12 +0000
Subject: [PATCH 2/2] Templatize mask ops
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 97 ++++---------------
1 file changed, 21 insertions(+), 76 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index b6de8b09ca8d6..95c20b1fabe58 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1270,15 +1270,15 @@ struct WgToSgVectorTransposeOp
}
};
-// This pattern distributes the vector.constant_mask ops to work at subgroup
-// level.
-struct WgToSgVectorConstantMaskOp
- : public OpConversionPattern<vector::ConstantMaskOp> {
- using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+// Distribute vector mask ops to work at subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+ using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ MaskOpType op,
+ typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
@@ -1288,73 +1288,16 @@ struct WgToSgVectorConstantMaskOp
VectorType type = op.getResult().getType();
auto wgShape = type.getShape();
- ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
-
- // Get subgroup ID.
- Value sgId =
- gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets =
- layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
- if (failed(sgOffsets))
- return failure();
-
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
- VectorType resultType = VectorType::get(sgShape, type.getElementType());
-
- // In each dimension, each subgroup computes its local mask size as:
- // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
- SmallVector<Value> newCreateMaskOps;
- for (auto offsetSet : *sgOffsets) {
- SmallVector<Value> maskOperands;
-
- for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
- Value wgMaskSizeVal =
- arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
- Value dimSizeVal =
- arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
- Value offset = offsetSet[i];
- Value adjustedMaskSize =
- arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
- Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
- Value nonNegative =
- arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
- Value sgMaskSize =
- arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
- maskOperands.push_back(sgMaskSize);
+ SmallVector<Value> wgMaskDimSizes;
+ if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+ for (int64_t maskSize : op.getMaskDimSizes()) {
+ wgMaskDimSizes.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, maskSize));
}
-
- auto newCreateMaskOp =
- vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
- xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
- layout.dropSgLayoutAndData());
- newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+ } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+ wgMaskDimSizes = llvm::to_vector(op.getOperands());
}
- rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
- return success();
- }
-};
-
-// This pattern distributes the vector.create_mask ops to work at subgroup
-// level.
-struct WgToSgVectorCreateMaskOp
- : public OpConversionPattern<vector::CreateMaskOp> {
- using OpConversionPattern<vector::CreateMaskOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::CreateMaskOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr layout =
- xegpu::getDistributeLayoutAttr(op.getResult());
- if (!layout || !layout.isForWorkgroup())
- return failure();
-
- Location loc = op.getLoc();
- VectorType type = op.getResult().getType();
- auto wgShape = type.getShape();
-
- auto wgMaskOperands = op.getOperands();
-
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
auto sgOffsets =
@@ -1366,17 +1309,17 @@ struct WgToSgVectorCreateMaskOp
VectorType resultType = VectorType::get(sgShape, type.getElementType());
// In each dimension, each subgroup computes its local mask size as:
- // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
+ // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
SmallVector<Value> newCreateMaskOps;
for (auto offsetSet : *sgOffsets) {
SmallVector<Value> maskOperands;
- for (auto [i, wgMaskOperand] : llvm::enumerate(wgMaskOperands)) {
+ for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
Value dimSizeVal =
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
Value offset = offsetSet[i];
Value adjustedMaskSize =
- arith::SubIOp::create(rewriter, loc, wgMaskOperand, offset);
+ arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value nonNegative =
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
@@ -1397,6 +1340,8 @@ struct WgToSgVectorCreateMaskOp
}
};
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
} // namespace
namespace mlir {
More information about the Mlir-commits
mailing list