[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute non-splat constant from wg to sg (PR #161416)

Nishant Patel llvmlistbot at llvm.org
Tue Sep 30 13:33:59 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/161416

>From 671215094dde611017c0f6c16e9665935820d462 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 26 Sep 2025 17:42:27 +0000
Subject: [PATCH 1/4] Add support for non-splat constants

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 123 +++++++++++++++---
 1 file changed, 107 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..8705f4aca0dd1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -711,7 +711,6 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
-// This pattern distributes arith.constant op into subgroup-level constants
 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
 
@@ -720,7 +719,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
     auto vecType = dyn_cast<VectorType>(op.getType());
-    if (!vecAttr || !vecAttr.isSplat() || !vecType)
+    if (!vecAttr || !vecType)
       return failure();
 
     xegpu::DistributeLayoutAttr layout =
@@ -733,22 +732,114 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     int count;
     std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
-    // Current limitation: constant of vector with single value.
-    // TODO: support more complex cases, e.g., vector with multiple values.
-    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
     auto newType = VectorType::get(sgShape, vecType.getElementType());
-    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
-    auto cstOp =
-        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-        !layout.getEffectiveInstDataAsInt().empty())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-    SmallVector<Value> newConsts(count, cstOp);
+    Location loc = op.getLoc();
+    auto eltType = vecType.getElementType();
 
-    rewriter.replaceOpWithMultiple(op, {newConsts});
-    return success();
+    auto setLayoutIfNeeded = [&](Value val) {
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+                                       layout.dropSgLayoutAndData());
+      }
+    };
+
+    if (vecAttr.isSplat()) {
+      // Splat: single value for all subgroups
+      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+      setLayoutIfNeeded(cstOp->getResult(0));
+      rewriter.replaceOp(op, cstOp);
+      return success();
+    } else if (sgShape == wgShape) { // if the entire vector is shared by all
+                                     // threads...don't distribute
+      auto newConstOp =
+          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+      setLayoutIfNeeded(newConstOp->getResult(0));
+      rewriter.replaceOp(op, newConstOp);
+      return success();
+    } else {
+      // Non-splat constant: use baseValue/stride logic for runtime indexing,
+      // with wrap-around
+      if (wgShape.size() >= 2 && wgShape[0] != 1 && wgShape[1] != 1)
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D or 2D vector constant supported");
+      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+      int64_t stride = 0;
+      if (values.size() > 1) {
+        stride = cast<IntegerAttr>(values[1]).getInt() -
+                 cast<IntegerAttr>(values[0]).getInt();
+        for (size_t i = 2; i < values.size(); ++i) {
+          int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+                         cast<IntegerAttr>(values[i - 1]).getInt();
+          if (diff != stride)
+            return rewriter.notifyMatchFailure(
+                op, "Non-constant stride in non-splat constant op.");
+        }
+      }
+
+      // Create a constant for the first tile
+      SmallVector<Attribute> tileValues;
+      int sgData = 1;
+      if (sgShape.size() == 1) {
+        sgData = static_cast<int>(sgShape[0]);
+      } else if (sgShape.size() == 2) {
+        // If shape is [1, n] or [n, 1], pick the non-1 dimension (n).
+        if (sgShape[0] == 1 && sgShape[1] != 1)
+          sgData = static_cast<int>(sgShape[1]);
+        else
+          sgData = static_cast<int>(sgShape[0]);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D or 2D vector constant supported");
+      }
+
+      for (int i = 0; i < sgData; ++i)
+        tileValues.push_back(values[i]);
+      auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
+                                             tileValues);
+      auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+      // Get subgroup/thread id
+      Value sgId =
+          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+      // Compute baseValue: baseValue = (sgId % numTiles) * stride * sgData
+      int64_t nonUnitDim = 0;
+      if (wgShape.size() == 2)
+        nonUnitDim = wgShape[0] != 1 ? 0 : 1;
+      // For 1D, just use the first dim
+      int64_t numTiles = wgShape[nonUnitDim] / sgShape[nonUnitDim];
+      auto numTileConst =
+          rewriter.create<arith::ConstantIndexOp>(loc, numTiles);
+      Value remsiOp = rewriter.create<arith::RemSIOp>(
+          loc, rewriter.getIndexType(), sgId, numTileConst);
+      auto baseValueConst =
+          rewriter.create<arith::ConstantIndexOp>(loc, stride * sgData);
+      Value baseValue = rewriter.create<arith::MulIOp>(
+          loc, rewriter.getIndexType(), remsiOp, baseValueConst);
+
+      // Broadcast baseValue to vector
+      auto splatBaseValue = rewriter.create<vector::SplatOp>(
+          loc, VectorType::get({sgData}, rewriter.getIndexType()), baseValue);
+
+      // Add baseValue to baseConstantVec constant
+      Value finalTile = rewriter.create<arith::AddIOp>(
+          loc, splatBaseValue->getResult(0), baseConstVec);
+
+      // Cast to final type if needed
+      Value result;
+      if (eltType.isIndex()) {
+        result = finalTile;
+      } else {
+        result = rewriter.create<arith::IndexCastOp>(loc, newType, finalTile);
+      }
+
+      setLayoutIfNeeded(result);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
   }
 };
 

>From 7d3746a09bc77c997382cb137c3a8b2326d28c6f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 30 Sep 2025 06:00:50 +0000
Subject: [PATCH 2/4] Support 1:N conversion

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 72 +++++++++----------
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    | 28 ++++++++
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 20 ++++++
 3 files changed, 80 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8705f4aca0dd1..2bbf1a85bb5be 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -753,18 +753,22 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       rewriter.replaceOp(op, cstOp);
       return success();
     } else if (sgShape == wgShape) { // if the entire vector is shared by all
-                                     // threads...don't distribute
+                                     // subgroups...don't distribute
       auto newConstOp =
           arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
       setLayoutIfNeeded(newConstOp->getResult(0));
       rewriter.replaceOp(op, newConstOp);
       return success();
     } else {
-      // Non-splat constant: use baseValue/stride logic for runtime indexing,
-      // with wrap-around
-      if (wgShape.size() >= 2 && wgShape[0] != 1 && wgShape[1] != 1)
+      // Non-splat constant
+      if (wgShape.size() > 2)
         return rewriter.notifyMatchFailure(
-            op, "Only 1D or 2D vector constant supported");
+            op, "Only 1D & 2D vector constant supported");
+
+      if (wgShape.size() == 2 && wgShape[0] != 1 && wgShape[1] != 1)
+        return rewriter.notifyMatchFailure(
+            op, "2D vector constant only supported with 1 unit dim");
+
       SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
       int64_t stride = 0;
       if (values.size() > 1) {
@@ -779,13 +783,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
         }
       }
 
-      // Create a constant for the first tile
+      // Create a constant for the base tile
       SmallVector<Attribute> tileValues;
       int sgData = 1;
       if (sgShape.size() == 1) {
         sgData = static_cast<int>(sgShape[0]);
       } else if (sgShape.size() == 2) {
-        // If shape is [1, n] or [n, 1], pick the non-1 dimension (n).
+        // If shape is [1, n] or [n, 1], pick the non-unit dimension.
         if (sgShape[0] == 1 && sgShape[1] != 1)
           sgData = static_cast<int>(sgShape[1]);
         else
@@ -801,43 +805,31 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
                                              tileValues);
       auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
 
-      // Get subgroup/thread id
+      // Get subgroup id
       Value sgId =
           gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
 
-      // Compute baseValue: baseValue = (sgId % numTiles) * stride * sgData
-      int64_t nonUnitDim = 0;
-      if (wgShape.size() == 2)
-        nonUnitDim = wgShape[0] != 1 ? 0 : 1;
-      // For 1D, just use the first dim
-      int64_t numTiles = wgShape[nonUnitDim] / sgShape[nonUnitDim];
-      auto numTileConst =
-          rewriter.create<arith::ConstantIndexOp>(loc, numTiles);
-      Value remsiOp = rewriter.create<arith::RemSIOp>(
-          loc, rewriter.getIndexType(), sgId, numTileConst);
-      auto baseValueConst =
-          rewriter.create<arith::ConstantIndexOp>(loc, stride * sgData);
-      Value baseValue = rewriter.create<arith::MulIOp>(
-          loc, rewriter.getIndexType(), remsiOp, baseValueConst);
-
-      // Broadcast baseValue to vector
-      auto splatBaseValue = rewriter.create<vector::SplatOp>(
-          loc, VectorType::get({sgData}, rewriter.getIndexType()), baseValue);
-
-      // Add baseValue to baseConstantVec constant
-      Value finalTile = rewriter.create<arith::AddIOp>(
-          loc, splatBaseValue->getResult(0), baseConstVec);
-
-      // Cast to final type if needed
-      Value result;
-      if (eltType.isIndex()) {
-        result = finalTile;
-      } else {
-        result = rewriter.create<arith::IndexCastOp>(loc, newType, finalTile);
+      auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+      if (failed(sgOffsets))
+        return failure();
+
+      SmallVector<Value> newConstOps;
+      for (auto offsets : *sgOffsets) {
+        // Multiply offset with stride and broadcast it to a vector of
+        // "sgData[nonUnitDim]" size
+        auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+        Value mulOffset = rewriter.create<arith::MulIOp>(
+            loc, rewriter.getIndexType(), offsets[0], strideConst);
+        auto bcastOffset = rewriter.create<vector::SplatOp>(
+            loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
+        auto finalConst =
+            arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
+        setLayoutIfNeeded(baseConstVec);
+        setLayoutIfNeeded(bcastOffset);
+        setLayoutIfNeeded(finalConst);
+        newConstOps.push_back(finalConst);
       }
-
-      setLayoutIfNeeded(result);
-      rewriter.replaceOp(op, result);
+      rewriter.replaceOpWithMultiple(op, {newConstOps});
       return success();
     }
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index dce73dee507e1..271d2b2f908fb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -98,4 +98,32 @@ gpu.module @test_distribution {
       : vector<256x64xf32> to vector<256xf32>
     gpu.return
   }
+
+  gpu.func @non_splat_constant() {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
+    // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
+    // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
+    // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index
+    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index
+    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]]
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
+    // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
+    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
+    // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
+    // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_1]] : index
+    // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
+    // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
+    %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 48fc633974e63..07b1e0f9ba8db 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -464,4 +464,24 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: non_splat_constant
+  gpu.func @non_splat_constant() {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
+    // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index
+    // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index
+    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]]
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
+    // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+    gpu.return
+  }
 }

>From 1b00dc76ebca1a475d1345387fd3edef0c34b659 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 30 Sep 2025 16:49:26 +0000
Subject: [PATCH 3/4] Restrict non-splat constants to index type and handle either unit dim

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 43 ++++++++++++-------
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |  3 +-
 2 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 2bbf1a85bb5be..be03e6e050c43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -711,6 +711,7 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
+// This pattern distributes arith.constant op into subgroup-level constants
 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
 
@@ -753,7 +754,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       rewriter.replaceOp(op, cstOp);
       return success();
     } else if (sgShape == wgShape) { // if the entire vector is shared by all
-                                     // subgroups...don't distribute
+                                     // subgroups, don't distribute
       auto newConstOp =
           arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
       setLayoutIfNeeded(newConstOp->getResult(0));
@@ -761,13 +762,28 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       return success();
     } else {
       // Non-splat constant
+      // Only supports 1D & 2D (with one unit dim)
+      // TODO: support other cases that require SLM access
+      if (!eltType.isIndex())
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported element type for non-splat constant op.");
+
+      SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
       if (wgShape.size() > 2)
         return rewriter.notifyMatchFailure(
             op, "Only 1D & 2D vector constant supported");
 
-      if (wgShape.size() == 2 && wgShape[0] != 1 && wgShape[1] != 1)
+      // allow 2D vector/distributions with one unit dim
+      auto hasTwoNonUnitDims = [](ArrayRef<int64_t> dims) {
+        return dims.size() == 2 && dims[0] != 1 && dims[1] != 1;
+      };
+      if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout))
         return rewriter.notifyMatchFailure(
-            op, "2D vector constant only supported with 1 unit dim");
+            op, "2D vector/distribution only supported with 1 unit dim");
+
+      int64_t nonUnitDim = 0;
+      if (wgShape.size() == 2)
+        nonUnitDim = wgShape[0] != 1 ? 0 : 1;
 
       SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
       int64_t stride = 0;
@@ -783,26 +799,22 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
         }
       }
 
-      // Create a constant for the base tile
-      SmallVector<Attribute> tileValues;
       int sgData = 1;
       if (sgShape.size() == 1) {
         sgData = static_cast<int>(sgShape[0]);
       } else if (sgShape.size() == 2) {
-        // If shape is [1, n] or [n, 1], pick the non-unit dimension.
-        if (sgShape[0] == 1 && sgShape[1] != 1)
-          sgData = static_cast<int>(sgShape[1]);
-        else
-          sgData = static_cast<int>(sgShape[0]);
+        sgData = static_cast<int>(sgShape[0] != 1 ? sgShape[0] : sgShape[1]);
       } else {
         return rewriter.notifyMatchFailure(
             op, "Only 1D or 2D vector constant supported");
       }
 
+      // Create a constant for the base tile
+      SmallVector<Attribute> baseTileValues;
       for (int i = 0; i < sgData; ++i)
-        tileValues.push_back(values[i]);
+        baseTileValues.push_back(values[i]);
       auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
-                                             tileValues);
+                                             baseTileValues);
       auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
 
       // Get subgroup id
@@ -813,13 +825,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       if (failed(sgOffsets))
         return failure();
 
+      auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
       SmallVector<Value> newConstOps;
       for (auto offsets : *sgOffsets) {
-        // Multiply offset with stride and broadcast it to a vector of
-        // "sgData[nonUnitDim]" size
-        auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+        // Multiply offset with stride, broadcast it and add to baseConstVec
         Value mulOffset = rewriter.create<arith::MulIOp>(
-            loc, rewriter.getIndexType(), offsets[0], strideConst);
+            loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
         auto bcastOffset = rewriter.create<vector::SplatOp>(
             loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
         auto finalConst =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 271d2b2f908fb..f3e2e41ae4b65 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -115,12 +115,11 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
     // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
     // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
-    // CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : index
     // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
     // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
     // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
     // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
-    // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_1]] : index
+    // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index
     // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
     // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
     %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>

>From 1381174b29812c6db58c46038aba2b718d9c9072 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 30 Sep 2025 20:33:23 +0000
Subject: [PATCH 4/4] Use vector.broadcast instead of vector.splat and fix CHECK lines

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  2 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    | 34 ++++++++++++-------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 13 ++++---
 3 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be03e6e050c43..9807cb98a5a83 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -831,7 +831,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
         // Multiply offset with stride, broadcast it and add to baseConstVec
         Value mulOffset = rewriter.create<arith::MulIOp>(
             loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
-        auto bcastOffset = rewriter.create<vector::SplatOp>(
+        auto bcastOffset = rewriter.create<vector::BroadcastOp>(
             loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
         auto finalConst =
             arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index f3e2e41ae4b65..9958d4ef6c1e2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -102,26 +102,34 @@ gpu.module @test_distribution {
   gpu.func @non_splat_constant() {
     // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
     // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
     // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
     // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
     // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
     // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
+    // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
     // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
-    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index
-    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index
+    // CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
     // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
-    // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]]
-    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[REM:.*]] = index.remu %[[MUL]], %[[C32]]
+    // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[REM2:.*]] = index.remu %[[AFF2]], %[[C1_3]]
     // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-    // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
-    // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
-    // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
-    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
-    // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
-    // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index
-    // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
-    // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
+    // CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C16]] : index
+    // CHECK-DAG: %[[C32_5:.*]] = arith.constant 32 : index
+    // CHECK-DAG: %[[REM3:.*]] = index.remu %[[ADD]], %[[C32_5]]
+    // CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[REM4:.*]] = index.remu %[[AFF2]], %[[C1_6]]
+    // CHECK-DAG: %[[C16_7:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_7]] : index
+    // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL2]] : index to vector<2xindex>
+    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<2xindex>
+    // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM3]], %[[C16_7]] : index
+    // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[MUL3]] : index to vector<2xindex>
+    // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[CST]], %[[BCAST2]] : vector<2xindex>
     %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
     gpu.return
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 85b78bb41db08..a2203f8e945d2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -463,18 +463,17 @@ gpu.module @test_distribution {
   gpu.func @non_splat_constant() {
     // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
     // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
     // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
     // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
-    // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index
-    // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index
-    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
-    // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]]
-    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[IDY]], %[[C32]]
+    // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[IDX]], %[[C1]]
     // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
     // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
-    // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex>
+    // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
     gpu.return
   }



More information about the Mlir-commits mailing list