[Mlir-commits] [mlir] [MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (PR #151977)

Nishant Patel llvmlistbot at llvm.org
Wed Aug 13 13:30:42 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/151977

>From 243bfef2b3e7d4607c162fc889c123af2d7c24e2 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 28 Jul 2025 17:05:28 +0000
Subject: [PATCH 1/4] Add wg-to-sg distribution pattern for arith.constant

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 58 ++++++++++++++++++-
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  7 +++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  7 +++
 3 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..878638061db5c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -649,6 +649,52 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
+// This pattern distributes arith.constant op into subgroup-level constants
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    auto vecType = dyn_cast<VectorType>(op.getType());
+    if (!vecAttr || !vecType)
+      return failure();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    ArrayRef<int64_t> wgShape = vecType.getShape();
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+    // Current limitation: constant of vector with single value.
+    // TODO: support more complex cases, e.g., vector with multiple values.
+    Attribute singleVal;
+    if (vecAttr.isSplat())
+      singleVal = vecAttr.getSplatValue<Attribute>();
+    else
+      return failure();
+
+    SmallVector<Value> newConsts;
+    auto newType = VectorType::get(sgShape, vecType.getElementType());
+    auto newLayout = layout.dropSgLayoutAndData();
+    for (int i = 0; i < count; ++i) {
+      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+      auto cstOp =
+          rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
+      if (newLayout)
+        xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+      newConsts.push_back(cstOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newConsts});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -657,8 +703,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
                UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+               WgToSgArithConstantOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -770,6 +816,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [=](arith::ConstantOp op) -> bool {
+        auto vecType = dyn_cast<VectorType>(op.getType());
+        if (!vecType)
+          return true;
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb487d8bf..65f4b46ad6d26 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -225,4 +225,11 @@ gpu.module @test_round_robin_assignment {
                                    target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK-COUNT-4: arith.constant dense<1.000000e+00> : vector<16x16xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d51122417fb61..415753a652092 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -393,4 +393,11 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
+
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
 }

>From 3f4b553e7f9bd41d52d96e9351725c7bcfb2b9f0 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 4 Aug 2025 18:43:18 +0000
Subject: [PATCH 2/4] Address review feedback: move the isSplat() check into the match guard

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp    | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 878638061db5c..a9529a9e4a125 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -658,7 +658,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
     auto vecType = dyn_cast<VectorType>(op.getType());
-    if (!vecAttr || !vecType)
+    if (!vecAttr || !vecAttr.isSplat() || !vecType)
       return failure();
 
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
@@ -672,11 +672,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
 
     // Current limitation: constant of vector with single value.
     // TODO: support more complex cases, e.g., vector with multiple values.
-    Attribute singleVal;
-    if (vecAttr.isSplat())
-      singleVal = vecAttr.getSplatValue<Attribute>();
-    else
-      return failure();
+    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
 
     SmallVector<Value> newConsts;
     auto newType = VectorType::get(sgShape, vecType.getElementType());

>From 1e4d8334ba69e96a504c8ad3230b02cbf346d5eb Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 13 Aug 2025 18:53:29 +0000
Subject: [PATCH 3/4] Address review feedback: create one subgroup constant and reuse it; drop round-robin test

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   | 16 ++++++----------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir   |  7 -------
 2 files changed, 6 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a9529a9e4a125..ead10e64a6ba6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -674,17 +674,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     // TODO: support more complex cases, e.g., vector with multiple values.
     Attribute singleVal = vecAttr.getSplatValue<Attribute>();
 
-    SmallVector<Value> newConsts;
     auto newType = VectorType::get(sgShape, vecType.getElementType());
-    auto newLayout = layout.dropSgLayoutAndData();
-    for (int i = 0; i < count; ++i) {
-      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
-      auto cstOp =
-          rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
-      if (newLayout)
-        xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
-      newConsts.push_back(cstOp);
-    }
+    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+    auto cstOp =
+        rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
+    if (auto newLayout = layout.dropSgLayoutAndData())
+      xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+    SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
     return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 65f4b46ad6d26..d67bdb487d8bf 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -225,11 +225,4 @@ gpu.module @test_round_robin_assignment {
                                    target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
     gpu.return
   }
-
-  // CHECK-LABEL: distribute_constant
-  gpu.func @distribute_constant() {
-    // CHECK-COUNT-4: arith.constant dense<1.000000e+00> : vector<16x16xf32>
-    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} dense<1.0> : vector<256x128xf32>
-    gpu.return
-  }
 }

>From e9c1517296dbe2a3e1c01d2214bc9e3ee0fc5b86 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 13 Aug 2025 20:30:22 +0000
Subject: [PATCH 4/4] Apply clang-format to the pattern registration list

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp      | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index b9008165a36e7..270d71aaa7273 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -690,13 +690,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
-           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-           WgToSgArithConstantOp>(
-          patterns.getContext());
+  patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+               WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+               WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+               WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
+               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir



More information about the Mlir-commits mailing list