[Mlir-commits] [mlir] af87214 - [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (#151977)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 13 13:52:10 PDT 2025


Author: Nishant Patel
Date: 2025-08-13T13:52:07-07:00
New Revision: af87214b849ba064dbe6dc13972b29cc49662fd2

URL: https://github.com/llvm/llvm-project/commit/af87214b849ba064dbe6dc13972b29cc49662fd2
DIFF: https://github.com/llvm/llvm-project/commit/af87214b849ba064dbe6dc13972b29cc49662fd2.diff

LOG: [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (#151977)

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 97c97ac3fd680..270d71aaa7273 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -647,17 +647,55 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
+// Distributes a splat vector arith.constant to one subgroup-shaped constant.
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    auto vecType = dyn_cast<VectorType>(op.getType());
+    if (!vecAttr || !vecAttr.isSplat() || !vecType)
+      return failure(); // only splat vector constants are handled
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure(); // requires a layout that carries sg_layout
+
+    ArrayRef<int64_t> wgShape = vecType.getShape();
+    SmallVector<int64_t> sgShape;
+    int count; // number of values the original result is replaced with
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+    // Current limitation: only splat (single-value) vector constants.
+    // TODO: support more complex cases, e.g., vector with multiple values.
+    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+
+    auto newType = VectorType::get(sgShape, vecType.getElementType());
+    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+    auto cstOp =
+        rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
+    if (auto newLayout = layout.dropSgLayoutAndData())
+      xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+    SmallVector<Value> newConsts(count, cstOp); // same value, count times
+
+    rewriter.replaceOpWithMultiple(op, {newConsts});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
-           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-          patterns.getContext());
+  patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+               WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+               WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+               WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
+               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -769,6 +807,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [=](arith::ConstantOp op) -> bool {
+        auto vecType = dyn_cast<VectorType>(op.getType());
+        if (!vecType)
+          return true; // non-vector constants are always legal
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 180ba8a162c9f..f4a49da71605f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -365,4 +365,11 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
+
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
 }


        


More information about the Mlir-commits mailing list