[Mlir-commits] [mlir] af87214 - [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (#151977)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 13 13:52:10 PDT 2025
Author: Nishant Patel
Date: 2025-08-13T13:52:07-07:00
New Revision: af87214b849ba064dbe6dc13972b29cc49662fd2
URL: https://github.com/llvm/llvm-project/commit/af87214b849ba064dbe6dc13972b29cc49662fd2
DIFF: https://github.com/llvm/llvm-project/commit/af87214b849ba064dbe6dc13972b29cc49662fd2.diff
LOG: [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (#151977)
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 97c97ac3fd680..270d71aaa7273 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -647,17 +647,55 @@ struct UnrealizedConversionCastOpPattern
}
};
+/// Distributes a workgroup-level vector `arith.constant` into subgroup-level
+/// constants according to the `sg_layout`/`sg_data` of its result layout.
+///
+/// Only splat constants are handled: every subgroup tile then holds the same
+/// value, so a single subgroup-shaped constant is created and reused for all
+/// `count` distributed results. Unsupported inputs are rejected through
+/// notifyMatchFailure so the reason is visible when the conversion fails.
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vecType = dyn_cast<VectorType>(op.getType());
+    if (!vecType)
+      return rewriter.notifyMatchFailure(op,
+                                         "expected a vector-typed constant");
+
+    // Current limitation: constant of vector with single (splat) value.
+    // TODO: support more complex cases, e.g., vector with multiple values.
+    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    if (!vecAttr || !vecAttr.isSplat())
+      return rewriter.notifyMatchFailure(
+          op, "only splat dense vector constants are supported");
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "expected a result layout with sg_layout");
+
+    ArrayRef<int64_t> wgShape = vecType.getShape();
+    auto [sgShape, count] = getSgShapeAndCount(wgShape, layout);
+
+    // All subgroup tiles share the splat value, so one constant of the
+    // subgroup shape stands in for every distributed value.
+    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+    auto newType = VectorType::get(sgShape, vecType.getElementType());
+    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+    auto cstOp =
+        rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
+    // Carry over whatever layout remains after dropping the subgroup fields;
+    // the result may be null, in which case no layout is attached.
+    if (auto newLayout = layout.dropSgLayoutAndData())
+      xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+    SmallVector<Value> newConsts(count, cstOp);
+
+    rewriter.replaceOpWithMultiple(op, {newConsts});
+    return success();
+  }
+};
+
} // namespace
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
- patterns
- .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
- WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
- WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
- WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
- patterns.getContext());
+ // Register the workgroup-to-subgroup distribution patterns listed below;
+ // WgToSgArithConstantOp handles splat vector arith.constant ops.
+ patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+ WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+ WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
+ WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -769,6 +807,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
+  // Non-vector constants are always legal; a vector constant is legal once
+  // the layout attached to its result is legal.
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [=](arith::ConstantOp op) -> bool {
+        return !isa<VectorType>(op.getType()) ||
+               isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 180ba8a162c9f..f4a49da71605f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -365,4 +365,11 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
} {sg_id_range = #xegpu.range<[3, 19]>}
gpu.return
}
+
+  // Splat constant distributed with sg_layout = [8, 4], sg_data = [32, 32]:
+  // each subgroup gets a 32x32 constant and the 256x128 one must disappear.
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK-NOT: vector<256x128xf32>
+    // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    // CHECK-NOT: vector<256x128xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
}
More information about the Mlir-commits mailing list