[Mlir-commits] [mlir] [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (PR #151977)
Chao Chen
llvmlistbot at llvm.org
Mon Aug 4 07:49:29 PDT 2025
================
@@ -649,6 +649,52 @@ struct UnrealizedConversionCastOpPattern
}
};
+// This pattern distributes arith.constant op into subgroup-level constants
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecAttr || !vecType)
+ return failure();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ ArrayRef<int64_t> wgShape = vecType.getShape();
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+ // Current limitation: constant of vector with single value.
+ // TODO: support more complex cases, e.g., vector with multiple values.
+ Attribute singleVal;
+ if (vecAttr.isSplat())
+ singleVal = vecAttr.getSplatValue<Attribute>();
+ else
+ return failure();
+
+ SmallVector<Value> newConsts;
+ auto newType = VectorType::get(sgShape, vecType.getElementType());
+ auto newLayout = layout.dropSgLayoutAndData();
+ for (int i = 0; i < count; ++i) {
----------------
chencha3 wrote:
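This loop seems unnecessary; you can build the vector directly with `SmallVector<Value> newConsts(count, cst)` (the fill constructor takes the element count first, then the value).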
https://github.com/llvm/llvm-project/pull/151977
More information about the Mlir-commits
mailing list