[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute non-splat constant from wg to sg (PR #161416)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 30 11:29:20 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/161416.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+109-15)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+27)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+20)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..be03e6e050c43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
auto vecType = dyn_cast<VectorType>(op.getType());
- if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ if (!vecAttr || !vecType)
return failure();
xegpu::DistributeLayoutAttr layout =
@@ -733,22 +733,116 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
- // Current limitation: constant of vector with single value.
- // TODO: support more complex cases, e.g., vector with multiple values.
- Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
auto newType = VectorType::get(sgShape, vecType.getElementType());
- auto sgAttr = DenseElementsAttr::get(newType, singleVal);
- auto cstOp =
- arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
- layout.dropSgLayoutAndData());
- SmallVector<Value> newConsts(count, cstOp);
+ Location loc = op.getLoc();
+ auto eltType = vecType.getElementType();
- rewriter.replaceOpWithMultiple(op, {newConsts});
- return success();
+ auto setLayoutIfNeeded = [&](Value val) {
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
+ }
+ };
+
+ if (vecAttr.isSplat()) {
+ // Splat: single value for all subgroups
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+ setLayoutIfNeeded(cstOp->getResult(0));
+ rewriter.replaceOp(op, cstOp);
+ return success();
+ } else if (sgShape == wgShape) { // if the entire vector is shared by all
+ // subgroups, don't distribute
+ auto newConstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+ setLayoutIfNeeded(newConstOp->getResult(0));
+ rewriter.replaceOp(op, newConstOp);
+ return success();
+ } else {
+ // Non-splat constant
+ // Only supports 1D & 2D (with one unit dim)
+ // TODO: support other cases that require SLM access
+ if (!eltType.isIndex())
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported element type for non-splat constant op.");
+
+ SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
+ if (wgShape.size() > 2)
+ return rewriter.notifyMatchFailure(
+ op, "Only 1D & 2D vector constant supported");
+
+ // allow 2D vector/distributions with one unit dim
+ auto hasTwoNonUnitDims = [](ArrayRef<int64_t> dims) {
+ return dims.size() == 2 && dims[0] != 1 && dims[1] != 1;
+ };
+ if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout))
+ return rewriter.notifyMatchFailure(
+ op, "2D vector/distribution only supported with 1 unit dim");
+
+ int64_t nonUnitDim = 0;
+ if (wgShape.size() == 2)
+ nonUnitDim = wgShape[0] != 1 ? 0 : 1;
+
+ SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+ int64_t stride = 0;
+ if (values.size() > 1) {
+ stride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (size_t i = 2; i < values.size(); ++i) {
+ int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+ cast<IntegerAttr>(values[i - 1]).getInt();
+ if (diff != stride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant stride in non-splat constant op.");
+ }
+ }
+
+ int sgData = 1;
+ if (sgShape.size() == 1) {
+ sgData = static_cast<int>(sgShape[0]);
+ } else if (sgShape.size() == 2) {
+ sgData = static_cast<int>(sgShape[0] != 1 ? sgShape[0] : sgShape[1]);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Only 1D or 2D vector constant supported");
+ }
+
+ // Create a constant for the base tile
+ SmallVector<Attribute> baseTileValues;
+ for (int i = 0; i < sgData; ++i)
+ baseTileValues.push_back(values[i]);
+ auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
+ baseTileValues);
+ auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+ // Get subgroup id
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+ auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+ SmallVector<Value> newConstOps;
+ for (auto offsets : *sgOffsets) {
+ // Multiply offset with stride, broadcast it and add to baseConstVec
+ Value mulOffset = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
+ auto bcastOffset = rewriter.create<vector::SplatOp>(
+ loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
+ auto finalConst =
+ arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
+ setLayoutIfNeeded(baseConstVec);
+ setLayoutIfNeeded(bcastOffset);
+ setLayoutIfNeeded(finalConst);
+ newConstOps.push_back(finalConst);
+ }
+ rewriter.replaceOpWithMultiple(op, {newConstOps});
+ return success();
+ }
}
};
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index dce73dee507e1..f3e2e41ae4b65 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -98,4 +98,31 @@ gpu.module @test_distribution {
: vector<256x64xf32> to vector<256xf32>
gpu.return
}
+
+ gpu.func @non_splat_constant() {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
+ // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
+ // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]]
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
+ // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
+ // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index
+ // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
+ // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
+ %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 48fc633974e63..07b1e0f9ba8db 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -464,4 +464,24 @@ gpu.module @test_distribution {
%broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
gpu.return
}
+
+ // CHECK-LABEL: non_splat_constant
+ gpu.func @non_splat_constant() {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
+ // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index
+ // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]]
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
+ // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+ gpu.return
+ }
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/161416
More information about the Mlir-commits mailing list