[Mlir-commits] [mlir] [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (PR #151977)
Nishant Patel
llvmlistbot at llvm.org
Mon Aug 4 11:43:38 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/151977
>From 243bfef2b3e7d4607c162fc889c123af2d7c24e2 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 28 Jul 2025 17:05:28 +0000
Subject: [PATCH 1/2] Add pattern for arith.constant
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 58 ++++++++++++++++++-
.../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 7 +++
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 7 +++
3 files changed, 70 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..878638061db5c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -649,6 +649,52 @@ struct UnrealizedConversionCastOpPattern
}
};
+// Distributes a splat vector arith.constant into subgroup-sized constants.
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    auto vecType = dyn_cast<VectorType>(op.getType());
+    if (!vecAttr || !vecType)
+      return failure();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    ArrayRef<int64_t> wgShape = vecType.getShape();
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+    // Current limitation: constant of vector with single value.
+    // TODO: support more complex cases, e.g., vector with multiple values.
+    Attribute singleVal;
+    if (vecAttr.isSplat())
+      singleVal = vecAttr.getSplatValue<Attribute>();
+    else
+      return failure();
+
+    SmallVector<Value> newConsts;
+    auto newType = VectorType::get(sgShape, vecType.getElementType());
+    auto newLayout = layout.dropSgLayoutAndData();
+    // Every tile holds the same splat value, so build the attribute once.
+    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+    for (int i = 0; i < count; ++i) {
+      auto cstOp =
+          rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
+      if (newLayout)
+        xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+      newConsts.push_back(cstOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newConsts});
+    return success();
+  }
+};
+
} // namespace
namespace mlir {
@@ -657,8 +703,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
- WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
- patterns.getContext());
+ WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+ WgToSgArithConstantOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -770,6 +816,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
+ target.addDynamicallyLegalOp<arith::ConstantOp>(
+ [=](arith::ConstantOp op) -> bool {
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecType)
+ return true;
+ return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ });
+
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb487d8bf..65f4b46ad6d26 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -225,4 +225,11 @@ gpu.module @test_round_robin_assignment {
target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
gpu.return
}
+
+ // CHECK-LABEL: distribute_constant
+ gpu.func @distribute_constant() {
+ // CHECK-COUNT-4: arith.constant dense<1.000000e+00> : vector<16x16xf32>
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} dense<1.0> : vector<256x128xf32>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d51122417fb61..415753a652092 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -393,4 +393,11 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
} {sg_id_range = #xegpu.range<[3, 19]>}
gpu.return
}
+
+ // CHECK-LABEL: distribute_constant
+ gpu.func @distribute_constant() {
+ // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
+ gpu.return
+ }
}
>From 3f4b553e7f9bd41d52d96e9351725c7bcfb2b9f0 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 4 Aug 2025 18:43:18 +0000
Subject: [PATCH 2/2] Feedback
---
.../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 878638061db5c..a9529a9e4a125 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -658,7 +658,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
auto vecType = dyn_cast<VectorType>(op.getType());
- if (!vecAttr || !vecType)
+ if (!vecAttr || !vecAttr.isSplat() || !vecType)
return failure();
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
@@ -672,11 +672,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Current limitation: constant of vector with single value.
// TODO: support more complex cases, e.g., vector with multiple values.
- Attribute singleVal;
- if (vecAttr.isSplat())
- singleVal = vecAttr.getSplatValue<Attribute>();
- else
- return failure();
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
SmallVector<Value> newConsts;
auto newType = VectorType::get(sgShape, vecType.getElementType());
More information about the Mlir-commits mailing list