[Mlir-commits] [mlir] [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (PR #151977)

Nishant Patel llvmlistbot at llvm.org
Thu Aug 7 10:50:01 PDT 2025


================
@@ -649,6 +649,52 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
+// This pattern distributes arith.constant op into subgroup-level constants
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    auto vecType = dyn_cast<VectorType>(op.getType());
+    if (!vecAttr || !vecType)
+      return failure();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    ArrayRef<int64_t> wgShape = vecType.getShape();
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+    // Current limitation: constant of vector with single value.
+    // TODO: support more complex cases, e.g., vector with multiple values.
+    Attribute singleVal;
+    if (vecAttr.isSplat())
+      singleVal = vecAttr.getSplatValue<Attribute>();
+    else
+      return failure();
+
+    SmallVector<Value> newConsts;
+    auto newType = VectorType::get(sgShape, vecType.getElementType());
+    auto newLayout = layout.dropSgLayoutAndData();
+    for (int i = 0; i < count; ++i) {
----------------
nbpatel wrote:

The loop creates multiple constants for the 1-to-N conversion, because users handled by other patterns expect a separate constant instruction for each subgroup-level value.

For example, in this IR:

%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} dense<1.0> : vector<256x128xf32>
    %addf = arith.addf %load, %cst {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xf32>
    
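The elementwise pattern expects the constant to be split into 4 constant ops, because addf will itself be split into 4 ops (256x128 distributed over sg_layout [8, 4] with sg_data [16, 16] gives 2 x 2 = 4 tiles per subgroup).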

https://github.com/llvm/llvm-project/pull/151977


More information about the Mlir-commits mailing list