[Mlir-commits] [mlir] 3c7873b - [MLIR][XeGPU] Distribute non-splat constant from wg to sg (#161416)

llvmlistbot at llvm.org
Thu Oct 9 10:32:49 PDT 2025


Author: Nishant Patel
Date: 2025-10-09T10:32:45-07:00
New Revision: 3c7873b75f2454be845af1fe2b601e0e8f791b4a

URL: https://github.com/llvm/llvm-project/commit/3c7873b75f2454be845af1fe2b601e0e8f791b4a
DIFF: https://github.com/llvm/llvm-project/commit/3c7873b75f2454be845af1fe2b601e0e8f791b4a.diff

LOG: [MLIR][XeGPU] Distribute non-splat constant from wg to sg (#161416)

This PR distributes non-splat constants from wg to sg. The current pattern is
limited to 1D and 2D index-typed constants with uniform row/column strides and
avoids cases that require SLM access.
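
As a sanity check on the approach, here is a minimal standalone C++ sketch (no
MLIR APIs; the 8x8 shape, the [2, 2] subgroup tile, and the stride values are
borrowed from the non_splat_constant_2D_non_unit_dim test added below, while
the variable names are purely illustrative): a constant with uniform
row/column strides lets every subgroup rebuild its tile as the shared base
tile plus one broadcast scalar, offsetRow * rowStride + offsetCol * colStride.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Workgroup-level 8x8 index constant with rowStride = 8, colStride = 16,
  // mirroring the non_splat_constant_2D_non_unit_dim test data.
  const int64_t rows = 8, cols = 8, rowStride = 8, colStride = 16;
  std::vector<int64_t> wgValues(rows * cols);
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < cols; ++c)
      wgValues[r * cols + c] = r * rowStride + c * colStride;

  // Base tile: the top-left sgRows x sgCols submatrix (sg_data = [2, 2]).
  const int64_t sgRows = 2, sgCols = 2;
  std::vector<int64_t> baseTile(sgRows * sgCols);
  for (int64_t r = 0; r < sgRows; ++r)
    for (int64_t c = 0; c < sgCols; ++c)
      baseTile[r * sgCols + c] = wgValues[r * cols + c];

  // Every subgroup tile equals baseTile plus one broadcast scalar:
  // offsetRow * rowStride + offsetCol * colStride.
  for (int64_t offR = 0; offR < rows; offR += sgRows)
    for (int64_t offC = 0; offC < cols; offC += sgCols) {
      int64_t addend = offR * rowStride + offC * colStride;
      for (int64_t r = 0; r < sgRows; ++r)
        for (int64_t c = 0; c < sgCols; ++c)
          assert(baseTile[r * sgCols + c] + addend ==
                 wgValues[(offR + r) * cols + (offC + c)]);
    }
  return 0;
}

This is also why the pattern rejects constants whose row or column strides are
not uniform: their per-subgroup tiles cannot be expressed as a base tile plus
a broadcast offset.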

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 784e5d68ce885..c28d2fc6c2b63 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
     auto vecType = dyn_cast<VectorType>(op.getType());
-    if (!vecAttr || !vecAttr.isSplat() || !vecType)
+    if (!vecAttr || !vecType)
       return failure();
 
     xegpu::DistributeLayoutAttr layout =
@@ -733,22 +733,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     int count;
     std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
-    // Current limitation: constant of vector with single value.
-    // TODO: support more complex cases, e.g., vector with multiple values.
-    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
     auto newType = VectorType::get(sgShape, vecType.getElementType());
-    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
-    auto cstOp =
-        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-        !layout.getEffectiveInstDataAsInt().empty())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-    SmallVector<Value> newConsts(count, cstOp);
+    Location loc = op.getLoc();
+    auto eltType = vecType.getElementType();
 
-    rewriter.replaceOpWithMultiple(op, {newConsts});
-    return success();
+    auto setLayoutIfNeeded = [&](Value val) {
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+                                       layout.dropSgLayoutAndData());
+      }
+    };
+
+    if (vecAttr.isSplat()) {
+      // Splat: single value for all subgroups
+      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+      setLayoutIfNeeded(cstOp->getResult(0));
+      rewriter.replaceOp(op, cstOp);
+      return success();
+    } else if (sgShape == wgShape) { // if the entire vector is shared by all
+                                     // subgroups, don't distribute
+      auto newConstOp =
+          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+      setLayoutIfNeeded(newConstOp->getResult(0));
+      rewriter.replaceOp(op, newConstOp);
+      return success();
+    } else {
+      // Non-splat constant
+      // Only supports 1D & 2D
+      // TODO: support other cases that require SLM access
+      if (!eltType.isIndex())
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported element type for non-splat constant op.");
+
+      if (wgShape.size() > 2)
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D & 2D vector constant supported");
+
+      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+      int64_t rowStride = 0, colStride = 0;
+      int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
+      int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
+
+      // Compute colStride and rowStride, and check for constant strides.
+      if (cols > 1) {
+        colStride = cast<IntegerAttr>(values[1]).getInt() -
+                    cast<IntegerAttr>(values[0]).getInt();
+      }
+      if (rows > 1) {
+        rowStride = cast<IntegerAttr>(values[cols]).getInt() -
+                    cast<IntegerAttr>(values[0]).getInt();
+      }
+
+      for (int64_t r = 0; r < rows; ++r) {
+        for (int64_t c = 0; c < cols; ++c) {
+          int64_t idx = r * cols + c;
+          // Check column stride
+          if (c > 0 && cols > 1) {
+            int64_t prevIdx = r * cols + (c - 1);
+            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+                           cast<IntegerAttr>(values[prevIdx]).getInt();
+            if (diff != colStride)
+              return rewriter.notifyMatchFailure(
+                  op, "Non-constant column stride in constant op.");
+          }
+          // Check row stride
+          if (r > 0 && rows > 1) {
+            int64_t prevIdx = (r - 1) * cols + c;
+            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+                           cast<IntegerAttr>(values[prevIdx]).getInt();
+            if (diff != rowStride)
+              return rewriter.notifyMatchFailure(
+                  op, "Non-constant row stride in constant op.");
+          }
+        }
+      }
+
+      // Create a constant for the base tile.
+      // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
+      // For 1D case, extract the first sgShape[0] elements.
+      SmallVector<Attribute> baseTileValues;
+      int baseTileCols = sgShape[sgShape.size() - 1];
+      int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
+      for (int64_t r = 0; r < baseTileRows; ++r) {
+        for (int64_t c = 0; c < baseTileCols; ++c) {
+          baseTileValues.push_back(values[r * cols + c]);
+        }
+      }
+
+      auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
+                                             baseTileValues);
+      auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+      // Get subgroup id
+      Value sgId =
+          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+      auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+      if (failed(sgOffsets))
+        return failure();
+
+      SmallVector<Value, 2> strideConsts;
+      strideConsts.push_back(
+          rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+      if (rows > 1)
+        strideConsts.insert(
+            strideConsts.begin(),
+            rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+
+      SmallVector<Value> newConstOps;
+      for (auto offsets : *sgOffsets) {
+        // Multiply offset with stride, broadcast it and add to baseConstVec
+        Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+        for (size_t i = 0; i < strideConsts.size(); ++i) {
+          Value mul = rewriter.create<arith::MulIOp>(
+              loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
+          mulOffset = rewriter.create<arith::AddIOp>(
+              loc, rewriter.getIndexType(), mulOffset, mul);
+        }
+        // Broadcast to baseConstVec size
+        auto bcastOffset = rewriter.create<vector::BroadcastOp>(
+            loc, baseConstVec.getType(), mulOffset);
+        auto finalConst =
+            arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
+        setLayoutIfNeeded(baseConstVec);
+        setLayoutIfNeeded(bcastOffset);
+        setLayoutIfNeeded(finalConst);
+        newConstOps.push_back(finalConst);
+      }
+      rewriter.replaceOpWithMultiple(op, {newConstOps});
+      return success();
+    }
   }
 };
 

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index dce73dee507e1..86a021b66949c 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -98,4 +98,31 @@ gpu.module @test_distribution {
       : vector<256x64xf32> to vector<256xf32>
     gpu.return
   }
+
+  gpu.func @non_splat_constant() {
+    // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]]
+    // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]]
+    // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
+    // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
+    // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
+    // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
+    // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
+    // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
+    // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
+    // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
+    // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
+    // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
+    // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
+    %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+    gpu.return
+  }
 }

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 38392fd10b742..742d11f8052ec 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -463,4 +463,68 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: non_splat_constant_2D
+  gpu.func @non_splat_constant_2D() {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: affine.apply #map4()[%[[SGID]]]
+    // CHECK-DAG: affine.apply #map5()[%[[SGID]]]
+    // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
+    // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
+    // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index
+    // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[STRIDEROW]] : index
+    // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex>
+    // CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+    gpu.return
+  }
+
+  // CHECK-LABEL: non_splat_constant_2D_non_unit_dim
+  gpu.func @non_splat_constant_2D_non_unit_dim() {
+    // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]]
+    // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
+    // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]]
+    // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]]
+    // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
+    // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
+    // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index
+    // CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi  %[[ADD]], %[[MUL6]] : index
+    // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<2x2xindex>
+    // CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
+    %cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
+         [0, 16, 32, 48, 64, 80, 96, 112],
+         [8, 24, 40, 56, 72, 88, 104, 120],
+         [16, 32, 48, 64, 80, 96, 112, 128],
+         [24, 40, 56, 72, 88, 104, 120, 136],
+         [32, 48, 64, 80, 96, 112, 128, 144],
+         [40, 56, 72, 88, 104, 120, 136, 152],
+         [48, 64, 80, 96, 112, 128, 144, 160],
+         [56, 72, 88, 104, 120, 136, 152, 168]
+      ]> : vector<8x8xindex>
+      gpu.return
+  }
+
+  // CHECK-LABEL: non_splat_constant
+  gpu.func @non_splat_constant() {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index
+    // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex>
+    // CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex>
+    %cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+    gpu.return
+  }
 }


        

