[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute non-splat constant from wg to sg (PR #161416)
Nishant Patel
llvmlistbot at llvm.org
Mon Oct 6 17:52:52 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/161416
>From 671215094dde611017c0f6c16e9665935820d462 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 26 Sep 2025 17:42:27 +0000
Subject: [PATCH 1/9] Add support for non-splattable constant
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 123 +++++++++++++++---
1 file changed, 107 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..8705f4aca0dd1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -711,7 +711,6 @@ struct UnrealizedConversionCastOpPattern
}
};
-// This pattern distributes arith.constant op into subgroup-level constants
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
@@ -720,7 +719,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
auto vecType = dyn_cast<VectorType>(op.getType());
- if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ if (!vecAttr || !vecType)
return failure();
xegpu::DistributeLayoutAttr layout =
@@ -733,22 +732,114 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
- // Current limitation: constant of vector with single value.
- // TODO: support more complex cases, e.g., vector with multiple values.
- Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
auto newType = VectorType::get(sgShape, vecType.getElementType());
- auto sgAttr = DenseElementsAttr::get(newType, singleVal);
- auto cstOp =
- arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
- layout.dropSgLayoutAndData());
- SmallVector<Value> newConsts(count, cstOp);
+ Location loc = op.getLoc();
+ auto eltType = vecType.getElementType();
- rewriter.replaceOpWithMultiple(op, {newConsts});
- return success();
+ auto setLayoutIfNeeded = [&](Value val) {
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
+ }
+ };
+
+ if (vecAttr.isSplat()) {
+ // Splat: single value for all subgroups
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+ setLayoutIfNeeded(cstOp->getResult(0));
+ rewriter.replaceOp(op, cstOp);
+ return success();
+ } else if (sgShape == wgShape) { // if the entire vector is shared by all
+ // threads...don't distribute
+ auto newConstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+ setLayoutIfNeeded(newConstOp->getResult(0));
+ rewriter.replaceOp(op, newConstOp);
+ return success();
+ } else {
+ // Non-splat constant: use baseValue/stride logic for runtime indexing,
+ // with wrap-around
+ if (wgShape.size() >= 2 && wgShape[0] != 1 && wgShape[1] != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Only 1D or 2D vector constant supported");
+ SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+ int64_t stride = 0;
+ if (values.size() > 1) {
+ stride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (size_t i = 2; i < values.size(); ++i) {
+ int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+ cast<IntegerAttr>(values[i - 1]).getInt();
+ if (diff != stride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant stride in non-splat constant op.");
+ }
+ }
+
+ // Create a constant for the first tile
+ SmallVector<Attribute> tileValues;
+ int sgData = 1;
+ if (sgShape.size() == 1) {
+ sgData = static_cast<int>(sgShape[0]);
+ } else if (sgShape.size() == 2) {
+ // If shape is [1, n] or [n, 1], pick the non-1 dimension (n).
+ if (sgShape[0] == 1 && sgShape[1] != 1)
+ sgData = static_cast<int>(sgShape[1]);
+ else
+ sgData = static_cast<int>(sgShape[0]);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Only 1D or 2D vector constant supported");
+ }
+
+ for (int i = 0; i < sgData; ++i)
+ tileValues.push_back(values[i]);
+ auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
+ tileValues);
+ auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+ // Get subgroup/thread id
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+ // Compute baseValue: baseValue = (sgId % numTiles) * stride * sgData
+ int64_t nonUnitDim = 0;
+ if (wgShape.size() == 2)
+ nonUnitDim = wgShape[0] != 1 ? 0 : 1;
+ // For 1D, just use the first dim
+ int64_t numTiles = wgShape[nonUnitDim] / sgShape[nonUnitDim];
+ auto numTileConst =
+ rewriter.create<arith::ConstantIndexOp>(loc, numTiles);
+ Value remsiOp = rewriter.create<arith::RemSIOp>(
+ loc, rewriter.getIndexType(), sgId, numTileConst);
+ auto baseValueConst =
+ rewriter.create<arith::ConstantIndexOp>(loc, stride * sgData);
+ Value baseValue = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), remsiOp, baseValueConst);
+
+ // Broadcast baseValue to vector
+ auto splatBaseValue = rewriter.create<vector::SplatOp>(
+ loc, VectorType::get({sgData}, rewriter.getIndexType()), baseValue);
+
+ // Add baseValue to baseConstantVec constant
+ Value finalTile = rewriter.create<arith::AddIOp>(
+ loc, splatBaseValue->getResult(0), baseConstVec);
+
+ // Cast to final type if needed
+ Value result;
+ if (eltType.isIndex()) {
+ result = finalTile;
+ } else {
+ result = rewriter.create<arith::IndexCastOp>(loc, newType, finalTile);
+ }
+
+ setLayoutIfNeeded(result);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
}
};
>From 7d3746a09bc77c997382cb137c3a8b2326d28c6f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 30 Sep 2025 06:00:50 +0000
Subject: [PATCH 2/9] Support 1:N conversion
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 72 +++++++++----------
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 28 ++++++++
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 20 ++++++
3 files changed, 80 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8705f4aca0dd1..2bbf1a85bb5be 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -753,18 +753,22 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
- // threads...don't distribute
+ // subgroups...don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
setLayoutIfNeeded(newConstOp->getResult(0));
rewriter.replaceOp(op, newConstOp);
return success();
} else {
- // Non-splat constant: use baseValue/stride logic for runtime indexing,
- // with wrap-around
- if (wgShape.size() >= 2 && wgShape[0] != 1 && wgShape[1] != 1)
+ // Non-splat constant
+ if (wgShape.size() > 2)
return rewriter.notifyMatchFailure(
- op, "Only 1D or 2D vector constant supported");
+ op, "Only 1D & 2D vector constant supported");
+
+ if (wgShape.size() == 2 && wgShape[0] != 1 && wgShape[1] != 1)
+ return rewriter.notifyMatchFailure(
+ op, "2D vector constant only supported with 1 unit dim");
+
SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
int64_t stride = 0;
if (values.size() > 1) {
@@ -779,13 +783,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
}
- // Create a constant for the first tile
+ // Create a constant for the base tile
SmallVector<Attribute> tileValues;
int sgData = 1;
if (sgShape.size() == 1) {
sgData = static_cast<int>(sgShape[0]);
} else if (sgShape.size() == 2) {
- // If shape is [1, n] or [n, 1], pick the non-1 dimension (n).
+ // If shape is [1, n] or [n, 1], pick the non-unit dimension.
if (sgShape[0] == 1 && sgShape[1] != 1)
sgData = static_cast<int>(sgShape[1]);
else
@@ -801,43 +805,31 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
tileValues);
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
- // Get subgroup/thread id
+ // Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- // Compute baseValue: baseValue = (sgId % numTiles) * stride * sgData
- int64_t nonUnitDim = 0;
- if (wgShape.size() == 2)
- nonUnitDim = wgShape[0] != 1 ? 0 : 1;
- // For 1D, just use the first dim
- int64_t numTiles = wgShape[nonUnitDim] / sgShape[nonUnitDim];
- auto numTileConst =
- rewriter.create<arith::ConstantIndexOp>(loc, numTiles);
- Value remsiOp = rewriter.create<arith::RemSIOp>(
- loc, rewriter.getIndexType(), sgId, numTileConst);
- auto baseValueConst =
- rewriter.create<arith::ConstantIndexOp>(loc, stride * sgData);
- Value baseValue = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), remsiOp, baseValueConst);
-
- // Broadcast baseValue to vector
- auto splatBaseValue = rewriter.create<vector::SplatOp>(
- loc, VectorType::get({sgData}, rewriter.getIndexType()), baseValue);
-
- // Add baseValue to baseConstantVec constant
- Value finalTile = rewriter.create<arith::AddIOp>(
- loc, splatBaseValue->getResult(0), baseConstVec);
-
- // Cast to final type if needed
- Value result;
- if (eltType.isIndex()) {
- result = finalTile;
- } else {
- result = rewriter.create<arith::IndexCastOp>(loc, newType, finalTile);
+ auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<Value> newConstOps;
+ for (auto offsets : *sgOffsets) {
+ // Multiply offset with stride and broadcast it to a vector of
+ // "sgData[nonUnitDim]" size
+ auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+ Value mulOffset = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[0], strideConst);
+ auto bcastOffset = rewriter.create<vector::SplatOp>(
+ loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
+ auto finalConst =
+ arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
+ setLayoutIfNeeded(baseConstVec);
+ setLayoutIfNeeded(bcastOffset);
+ setLayoutIfNeeded(finalConst);
+ newConstOps.push_back(finalConst);
}
-
- setLayoutIfNeeded(result);
- rewriter.replaceOp(op, result);
+ rewriter.replaceOpWithMultiple(op, {newConstOps});
return success();
}
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index dce73dee507e1..271d2b2f908fb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -98,4 +98,32 @@ gpu.module @test_distribution {
: vector<256x64xf32> to vector<256xf32>
gpu.return
}
+
+ gpu.func @non_splat_constant() {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
+ // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
+ // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]]
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
+ // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
+ // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_1]] : index
+ // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
+ // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
+ %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 48fc633974e63..07b1e0f9ba8db 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -464,4 +464,24 @@ gpu.module @test_distribution {
%broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
gpu.return
}
+
+ // CHECK-LABEL: non_splat_constant
+ gpu.func @non_splat_constant() {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
+ // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index
+ // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]]
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
+ // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+ gpu.return
+ }
}
>From 1b00dc76ebca1a475d1345387fd3edef0c34b659 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 30 Sep 2025 16:49:26 +0000
Subject: [PATCH 3/9] All cases work
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 43 ++++++++++++-------
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 3 +-
2 files changed, 28 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 2bbf1a85bb5be..be03e6e050c43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -711,6 +711,7 @@ struct UnrealizedConversionCastOpPattern
}
};
+// This pattern distributes arith.constant op into subgroup-level constants
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
@@ -753,7 +754,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
- // subgroups...don't distribute
+ // subgroups, don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
setLayoutIfNeeded(newConstOp->getResult(0));
@@ -761,13 +762,28 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
return success();
} else {
// Non-splat constant
+ // Only supports 1D & 2D (with one unit dim)
+ // TODO: support other cases that require SLM access
+ if (!eltType.isIndex())
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported element type for non-splat constant op.");
+
+ SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
if (wgShape.size() > 2)
return rewriter.notifyMatchFailure(
op, "Only 1D & 2D vector constant supported");
- if (wgShape.size() == 2 && wgShape[0] != 1 && wgShape[1] != 1)
+ // allow 2D vector/distributions with one unit dim
+ auto hasTwoNonUnitDims = [](ArrayRef<int64_t> dims) {
+ return dims.size() == 2 && dims[0] != 1 && dims[1] != 1;
+ };
+ if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout))
return rewriter.notifyMatchFailure(
- op, "2D vector constant only supported with 1 unit dim");
+ op, "2D vector/distribution only supported with 1 unit dim");
+
+ int64_t nonUnitDim = 0;
+ if (wgShape.size() == 2)
+ nonUnitDim = wgShape[0] != 1 ? 0 : 1;
SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
int64_t stride = 0;
@@ -783,26 +799,22 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
}
- // Create a constant for the base tile
- SmallVector<Attribute> tileValues;
int sgData = 1;
if (sgShape.size() == 1) {
sgData = static_cast<int>(sgShape[0]);
} else if (sgShape.size() == 2) {
- // If shape is [1, n] or [n, 1], pick the non-unit dimension.
- if (sgShape[0] == 1 && sgShape[1] != 1)
- sgData = static_cast<int>(sgShape[1]);
- else
- sgData = static_cast<int>(sgShape[0]);
+ sgData = static_cast<int>(sgShape[0] != 1 ? sgShape[0] : sgShape[1]);
} else {
return rewriter.notifyMatchFailure(
op, "Only 1D or 2D vector constant supported");
}
+ // Create a constant for the base tile
+ SmallVector<Attribute> baseTileValues;
for (int i = 0; i < sgData; ++i)
- tileValues.push_back(values[i]);
+ baseTileValues.push_back(values[i]);
auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
- tileValues);
+ baseTileValues);
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
// Get subgroup id
@@ -813,13 +825,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
if (failed(sgOffsets))
return failure();
+ auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
- // Multiply offset with stride and broadcast it to a vector of
- // "sgData[nonUnitDim]" size
- auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+ // Multiply offset with stride, broadcast it and add to baseConstVec
Value mulOffset = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[0], strideConst);
+ loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
auto bcastOffset = rewriter.create<vector::SplatOp>(
loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
auto finalConst =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 271d2b2f908fb..f3e2e41ae4b65 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -115,12 +115,11 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
- // CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
// CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
// CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
// CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
- // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_1]] : index
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index
// CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
// CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
>From 1381174b29812c6db58c46038aba2b718d9c9072 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 30 Sep 2025 20:33:23 +0000
Subject: [PATCH 4/9] Fix CHECKS
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 34 ++++++++++++-------
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 13 ++++---
3 files changed, 28 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be03e6e050c43..9807cb98a5a83 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -831,7 +831,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Multiply offset with stride, broadcast it and add to baseConstVec
Value mulOffset = rewriter.create<arith::MulIOp>(
loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
- auto bcastOffset = rewriter.create<vector::SplatOp>(
+ auto bcastOffset = rewriter.create<vector::BroadcastOp>(
loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index f3e2e41ae4b65..9958d4ef6c1e2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -102,26 +102,34 @@ gpu.module @test_distribution {
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
// CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
// CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
+ // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index
- // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index
+ // CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
- // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]]
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[REM:.*]] = index.remu %[[MUL]], %[[C32]]
+ // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[REM2:.*]] = index.remu %[[AFF2]], %[[C1_3]]
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
- // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
- // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
- // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
- // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
- // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index
- // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
- // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
+ // CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C16]] : index
+ // CHECK-DAG: %[[C32_5:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[REM3:.*]] = index.remu %[[ADD]], %[[C32_5]]
+ // CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[REM4:.*]] = index.remu %[[AFF2]], %[[C1_6]]
+ // CHECK-DAG: %[[C16_7:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_7]] : index
+ // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL2]] : index to vector<2xindex>
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<2xindex>
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM3]], %[[C16_7]] : index
+ // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[MUL3]] : index to vector<2xindex>
+ // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[CST]], %[[BCAST2]] : vector<2xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 85b78bb41db08..a2203f8e945d2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -463,18 +463,17 @@ gpu.module @test_distribution {
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
// CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index
- // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index
- // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
- // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]]
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[IDY]], %[[C32]]
+ // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[IDX]], %[[C1]]
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
- // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex>
+ // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex>
+ // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
>From e77edddfc6ae346d0537f238c74b1b7524ec163e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 5 Oct 2025 21:55:42 +0000
Subject: [PATCH 5/9] Support 2D case
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 129 +++++++++++++-----
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 50 +++----
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 22 ++-
3 files changed, 127 insertions(+), 74 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9807cb98a5a83..b7107011ee178 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -773,48 +773,96 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
return rewriter.notifyMatchFailure(
op, "Only 1D & 2D vector constant supported");
- // allow 2D vector/distributions with one unit dim
- auto hasTwoNonUnitDims = [](ArrayRef<int64_t> dims) {
- return dims.size() == 2 && dims[0] != 1 && dims[1] != 1;
- };
- if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout))
- return rewriter.notifyMatchFailure(
- op, "2D vector/distribution only supported with 1 unit dim");
-
- int64_t nonUnitDim = 0;
- if (wgShape.size() == 2)
- nonUnitDim = wgShape[0] != 1 ? 0 : 1;
-
SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
int64_t stride = 0;
- if (values.size() > 1) {
- stride = cast<IntegerAttr>(values[1]).getInt() -
- cast<IntegerAttr>(values[0]).getInt();
- for (size_t i = 2; i < values.size(); ++i) {
- int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
- cast<IntegerAttr>(values[i - 1]).getInt();
- if (diff != stride)
- return rewriter.notifyMatchFailure(
- op, "Non-constant stride in non-splat constant op.");
+ int64_t rowStride = 0, colStride = 0;
+ if (wgShape.size() == 1) {
+ // 1D case: single stride
+ if (values.size() > 1) {
+ stride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (size_t i = 2; i < values.size(); ++i) {
+ int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+ cast<IntegerAttr>(values[i - 1]).getInt();
+ if (diff != stride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant stride in non-splat constant op.");
+ }
+ }
+ } else if (wgShape.size() == 2) {
+ // 2D case: row stride and column stride
+ int64_t rows = wgShape[0], cols = wgShape[1];
+ if (values.size() != static_cast<size_t>(rows * cols))
+ return rewriter.notifyMatchFailure(
+ op, "Mismatch between vector shape and constant values size.");
+ // Compute col stride (stride between elements in a column)
+ if (cols > 1) {
+ colStride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 1; c < cols; ++c) {
+ int64_t idx = r * cols + c;
+ int64_t prevIdx = r * cols + (c - 1);
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != colStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant column stride in 2D constant op.");
+ }
+ }
+ }
+ // Compute row stride (stride between elements in a row)
+ if (rows > 1) {
+ rowStride = cast<IntegerAttr>(values[cols]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (int64_t c = 0; c < cols; ++c) {
+ for (int64_t r = 1; r < rows; ++r) {
+ int64_t idx = r * cols + c;
+ int64_t prevIdx = (r - 1) * cols + c;
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != rowStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant row stride in 2D constant op.");
+ }
+ }
}
}
- int sgData = 1;
+ // Determine the shape of the base tile for each subgroup.
+ SmallVector<int64_t> baseTileShape;
if (sgShape.size() == 1) {
- sgData = static_cast<int>(sgShape[0]);
+ baseTileShape.push_back(sgShape[0]);
} else if (sgShape.size() == 2) {
- sgData = static_cast<int>(sgShape[0] != 1 ? sgShape[0] : sgShape[1]);
+ baseTileShape = sgShape;
} else {
return rewriter.notifyMatchFailure(
op, "Only 1D or 2D vector constant supported");
}
- // Create a constant for the base tile
+ // Compute the number of elements in the base tile.
+ int64_t baseTileElemCount = 1;
+ for (int64_t d : baseTileShape)
+ baseTileElemCount *= d;
+
+ // Create a constant for the base tile.
+ // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
SmallVector<Attribute> baseTileValues;
- for (int i = 0; i < sgData; ++i)
- baseTileValues.push_back(values[i]);
- auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
- baseTileValues);
+ if (baseTileShape.size() == 2) {
+ int64_t rows = baseTileShape[0], cols = baseTileShape[1];
+ int64_t wgRows = wgShape[0], wgCols = wgShape[1];
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 0; c < cols; ++c) {
+ baseTileValues.push_back(values[r * wgCols + c]);
+ }
+ }
+ } else {
+ // 1D case
+ for (int64_t i = 0; i < baseTileElemCount; ++i)
+ baseTileValues.push_back(values[i]);
+ }
+ auto tileAttr = DenseElementsAttr::get(
+ VectorType::get(baseTileShape, eltType), baseTileValues);
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
// Get subgroup id
@@ -826,13 +874,30 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
return failure();
auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+ auto strideConstRow =
+ rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
+ auto strideConstCol =
+ rewriter.create<arith::ConstantIndexOp>(loc, colStride);
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
- Value mulOffset = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
+ Value mulOffset;
+ if (baseTileShape.size() == 1) {
+ // 1D: offset[0] * strideConst
+ mulOffset = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[0], strideConst);
+ } else if (baseTileShape.size() == 2) {
+ // 2D: offset[0]*strideConstRow + offset[1]*strideConstCol
+ Value rowMul = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[0], strideConstRow);
+ Value colMul = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[1], strideConstCol);
+ mulOffset = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), rowMul, colMul);
+ }
+ // Broadcast to baseConstVec size
auto bcastOffset = rewriter.create<vector::BroadcastOp>(
- loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
+ loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 9958d4ef6c1e2..c2e51bdb71485 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -100,36 +100,26 @@ gpu.module @test_distribution {
}
gpu.func @non_splat_constant() {
- // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
- // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
- // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
- // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
- // CHECK-DAG: %[[REM:.*]] = index.remu %[[MUL]], %[[C32]]
- // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[REM2:.*]] = index.remu %[[AFF2]], %[[C1_3]]
- // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
- // CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C16]] : index
- // CHECK-DAG: %[[C32_5:.*]] = arith.constant 32 : index
- // CHECK-DAG: %[[REM3:.*]] = index.remu %[[ADD]], %[[C32_5]]
- // CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[REM4:.*]] = index.remu %[[AFF2]], %[[C1_6]]
- // CHECK-DAG: %[[C16_7:.*]] = arith.constant 16 : index
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_7]] : index
- // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL2]] : index to vector<2xindex>
- // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<2xindex>
- // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM3]], %[[C16_7]] : index
- // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[MUL3]] : index to vector<2xindex>
- // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[CST]], %[[BCAST2]] : vector<2xindex>
+ // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]]
+ // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]]
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]]
+ // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+ // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
+ // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
+ // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
+ // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[STRIDE1]], %[[STRIDE2]] : index
+ // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
+ // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
+ // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[STRIDE3]], %[[STRIDE4]] : index
+ // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES2]] : index to vector<2x1xindex>
+ // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index a2203f8e945d2..51158fa11a9ec 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -461,19 +461,17 @@ gpu.module @test_distribution {
// CHECK-LABEL: non_splat_constant
gpu.func @non_splat_constant() {
- // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
- // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[IDY]], %[[C32]]
- // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[IDX]], %[[C1]]
- // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
- // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
- // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex>
- // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
+ // CHECK-DAG: affine.apply #map4()[%[[SGID]]]
+ // CHECK-DAG: affine.apply #map5()[%[[SGID]]]
+ // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
+ // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
+ // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[STRIDECOL]], %[[STRIDEROW]] : index
+ // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex>
+ // CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
>From 1b779b7d39cfb3650f0eec1dd6c5b7bace1dd4a9 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 6 Oct 2025 05:43:07 +0000
Subject: [PATCH 6/9] Add 2D test case
---
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 30 +++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 51158fa11a9ec..5f990a49f1298 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -475,4 +475,34 @@ gpu.module @test_distribution {
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
+
+ // CHECK-LABEL: non_splat_constant_2D_non_unit_dim
+ gpu.func @non_splat_constant_2D_non_unit_dim() {
+ // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]]
+ // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]]
+ // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]]
+ // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
+ // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[ADDIDX:.*]] = arith.addi %[[MUL5]], %[[MUL6]] : index
+ // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDIDX]] : index to vector<2x2xindex>
+ // CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
+ %cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
+ [0, 16, 32, 48, 64, 80, 96, 112],
+ [8, 24, 40, 56, 72, 88, 104, 120],
+ [16, 32, 48, 64, 80, 96, 112, 128],
+ [24, 40, 56, 72, 88, 104, 120, 136],
+ [32, 48, 64, 80, 96, 112, 128, 144],
+ [40, 56, 72, 88, 104, 120, 136, 152],
+ [48, 64, 80, 96, 112, 128, 144, 160],
+ [56, 72, 88, 104, 120, 136, 152, 168]
+ ]> : vector<8x8xindex>
+ gpu.return
+ }
}
>From 1b8db0e54a0a6d5d3e1258f3e9793ab2065e03fe Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 6 Oct 2025 05:47:56 +0000
Subject: [PATCH 7/9] Clean up
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 29 +++++++------------
1 file changed, 10 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index b7107011ee178..2862400c85cca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -762,13 +762,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
return success();
} else {
// Non-splat constant
- // Only supports 1D & 2D (with one unit dim)
+ // Only supports 1D & 2D
// TODO: support other cases that require SLM access
if (!eltType.isIndex())
return rewriter.notifyMatchFailure(
op, "Unsupported element type for non-splat constant op.");
- SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
if (wgShape.size() > 2)
return rewriter.notifyMatchFailure(
op, "Only 1D & 2D vector constant supported");
@@ -792,9 +791,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
} else if (wgShape.size() == 2) {
// 2D case: row stride and column stride
int64_t rows = wgShape[0], cols = wgShape[1];
- if (values.size() != static_cast<size_t>(rows * cols))
- return rewriter.notifyMatchFailure(
- op, "Mismatch between vector shape and constant values size.");
// Compute col stride (stride between elements in a column)
if (cols > 1) {
colStride = cast<IntegerAttr>(values[1]).getInt() -
@@ -840,17 +836,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
op, "Only 1D or 2D vector constant supported");
}
- // Compute the number of elements in the base tile.
- int64_t baseTileElemCount = 1;
- for (int64_t d : baseTileShape)
- baseTileElemCount *= d;
-
// Create a constant for the base tile.
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
SmallVector<Attribute> baseTileValues;
if (baseTileShape.size() == 2) {
int64_t rows = baseTileShape[0], cols = baseTileShape[1];
- int64_t wgRows = wgShape[0], wgCols = wgShape[1];
+ int64_t wgCols = wgShape[1];
for (int64_t r = 0; r < rows; ++r) {
for (int64_t c = 0; c < cols; ++c) {
baseTileValues.push_back(values[r * wgCols + c]);
@@ -858,7 +849,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
} else {
// 1D case
- for (int64_t i = 0; i < baseTileElemCount; ++i)
+ for (int64_t i = 0; i < computeProduct(baseTileShape); ++i)
baseTileValues.push_back(values[i]);
}
auto tileAttr = DenseElementsAttr::get(
@@ -874,24 +865,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
return failure();
auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
- auto strideConstRow =
+ auto rowStrideConst =
rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
- auto strideConstCol =
+ auto colStrideConst =
rewriter.create<arith::ConstantIndexOp>(loc, colStride);
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
Value mulOffset;
- if (baseTileShape.size() == 1) {
+ if (wgShape.size() == 1) {
// 1D: offset[0] * strideConst
mulOffset = rewriter.create<arith::MulIOp>(
loc, rewriter.getIndexType(), offsets[0], strideConst);
- } else if (baseTileShape.size() == 2) {
- // 2D: offset[0]*strideConstRow + offset[1]*strideConstCol
+ } else if (wgShape.size() == 2) {
+ // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst
Value rowMul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[0], strideConstRow);
+ loc, rewriter.getIndexType(), offsets[0], rowStrideConst);
Value colMul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[1], strideConstCol);
+ loc, rewriter.getIndexType(), offsets[1], colStrideConst);
mulOffset = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), rowMul, colMul);
}
>From fabb41919b3ac7a24e2193e01c80f93d2933636e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 6 Oct 2025 22:44:27 +0000
Subject: [PATCH 8/9] Clean up
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 65 +++++++------------
1 file changed, 24 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 2862400c85cca..dd9f50967534a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -825,35 +825,21 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
}
- // Determine the shape of the base tile for each subgroup.
- SmallVector<int64_t> baseTileShape;
- if (sgShape.size() == 1) {
- baseTileShape.push_back(sgShape[0]);
- } else if (sgShape.size() == 2) {
- baseTileShape = sgShape;
- } else {
- return rewriter.notifyMatchFailure(
- op, "Only 1D or 2D vector constant supported");
- }
-
// Create a constant for the base tile.
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
+ // For 1D case, extract the first sgShape[0] elements.
SmallVector<Attribute> baseTileValues;
- if (baseTileShape.size() == 2) {
- int64_t rows = baseTileShape[0], cols = baseTileShape[1];
- int64_t wgCols = wgShape[1];
- for (int64_t r = 0; r < rows; ++r) {
- for (int64_t c = 0; c < cols; ++c) {
- baseTileValues.push_back(values[r * wgCols + c]);
- }
+ int cols = sgShape[sgShape.size() - 1];
+ int64_t wgCols = wgShape[sgShape.size() - 1];
+ int64_t rows = sgShape.size() == 1 ? 1 : sgShape[0];
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 0; c < cols; ++c) {
+ baseTileValues.push_back(values[r * wgCols + c]);
}
- } else {
- // 1D case
- for (int64_t i = 0; i < computeProduct(baseTileShape); ++i)
- baseTileValues.push_back(values[i]);
}
- auto tileAttr = DenseElementsAttr::get(
- VectorType::get(baseTileShape, eltType), baseTileValues);
+
+ auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
+ baseTileValues);
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
// Get subgroup id
@@ -864,27 +850,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
if (failed(sgOffsets))
return failure();
- auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
- auto rowStrideConst =
- rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
- auto colStrideConst =
- rewriter.create<arith::ConstantIndexOp>(loc, colStride);
+ SmallVector<Value, 2> strideConsts;
+ strideConsts.push_back(
+ rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+ strideConsts.push_back(
+ rewriter.create<arith::ConstantIndexOp>(loc, colStride));
SmallVector<Value> newConstOps;
+ Value mulOffset;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
- Value mulOffset;
- if (wgShape.size() == 1) {
- // 1D: offset[0] * strideConst
- mulOffset = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[0], strideConst);
- } else if (wgShape.size() == 2) {
- // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst
- Value rowMul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[0], rowStrideConst);
- Value colMul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[1], colStrideConst);
+ SmallVector<Value> muls;
+ for (size_t i = 0; i < strideConsts.size(); ++i) {
+ muls.push_back(rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[i], strideConsts[i]));
+ }
+ mulOffset = muls.front();
+ if (muls.size() > 1) {
mulOffset = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), rowMul, colMul);
+ loc, rewriter.getIndexType(), mulOffset, muls[1]);
}
// Broadcast to baseConstVec size
auto bcastOffset = rewriter.create<vector::BroadcastOp>(
>From 2c81deeb011b1f3e4e7a97c731f715c0b4b6d9f8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 7 Oct 2025 00:32:18 +0000
Subject: [PATCH 9/9] Refactor
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 94 ++++++++-----------
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 4 +-
2 files changed, 43 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index dd9f50967534a..659039b41638d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -773,54 +773,40 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
op, "Only 1D & 2D vector constant supported");
SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
- int64_t stride = 0;
int64_t rowStride = 0, colStride = 0;
- if (wgShape.size() == 1) {
- // 1D case: single stride
- if (values.size() > 1) {
- stride = cast<IntegerAttr>(values[1]).getInt() -
- cast<IntegerAttr>(values[0]).getInt();
- for (size_t i = 2; i < values.size(); ++i) {
- int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
- cast<IntegerAttr>(values[i - 1]).getInt();
- if (diff != stride)
+ int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
+ int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
+
+ // Compute colStride and rowStride, and check for constant strides.
+ if (cols > 1) {
+ colStride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ }
+ if (rows > 1) {
+ rowStride = cast<IntegerAttr>(values[cols]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ }
+
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 0; c < cols; ++c) {
+ int64_t idx = r * cols + c;
+ // Check column stride (skip first column)
+ if (c > 0 && cols > 1) {
+ int64_t prevIdx = r * cols + (c - 1);
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != colStride)
return rewriter.notifyMatchFailure(
- op, "Non-constant stride in non-splat constant op.");
- }
- }
- } else if (wgShape.size() == 2) {
- // 2D case: row stride and column stride
- int64_t rows = wgShape[0], cols = wgShape[1];
- // Compute col stride (stride between elements in a column)
- if (cols > 1) {
- colStride = cast<IntegerAttr>(values[1]).getInt() -
- cast<IntegerAttr>(values[0]).getInt();
- for (int64_t r = 0; r < rows; ++r) {
- for (int64_t c = 1; c < cols; ++c) {
- int64_t idx = r * cols + c;
- int64_t prevIdx = r * cols + (c - 1);
- int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
- cast<IntegerAttr>(values[prevIdx]).getInt();
- if (diff != colStride)
- return rewriter.notifyMatchFailure(
- op, "Non-constant column stride in 2D constant op.");
- }
+ op, "Non-constant column stride in constant op.");
}
- }
- // Compute row stride (stride between elements in a row)
- if (rows > 1) {
- rowStride = cast<IntegerAttr>(values[cols]).getInt() -
- cast<IntegerAttr>(values[0]).getInt();
- for (int64_t c = 0; c < cols; ++c) {
- for (int64_t r = 1; r < rows; ++r) {
- int64_t idx = r * cols + c;
- int64_t prevIdx = (r - 1) * cols + c;
- int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
- cast<IntegerAttr>(values[prevIdx]).getInt();
- if (diff != rowStride)
- return rewriter.notifyMatchFailure(
- op, "Non-constant row stride in 2D constant op.");
- }
+ // Check row stride (skip first row)
+ if (r > 0 && rows > 1) {
+ int64_t prevIdx = (r - 1) * cols + c;
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != rowStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant row stride in constant op.");
}
}
}
@@ -829,12 +815,11 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
// For 1D case, extract the first sgShape[0] elements.
SmallVector<Attribute> baseTileValues;
- int cols = sgShape[sgShape.size() - 1];
- int64_t wgCols = wgShape[sgShape.size() - 1];
- int64_t rows = sgShape.size() == 1 ? 1 : sgShape[0];
- for (int64_t r = 0; r < rows; ++r) {
- for (int64_t c = 0; c < cols; ++c) {
- baseTileValues.push_back(values[r * wgCols + c]);
+ int baseTileCols = sgShape[sgShape.size() - 1];
+ int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
+ for (int64_t r = 0; r < baseTileRows; ++r) {
+ for (int64_t c = 0; c < baseTileCols; ++c) {
+ baseTileValues.push_back(values[r * cols + c]);
}
}
@@ -851,10 +836,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
return failure();
SmallVector<Value, 2> strideConsts;
- strideConsts.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
strideConsts.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+ if (rows > 1)
+ strideConsts.insert(
+ strideConsts.begin(),
+ rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+
SmallVector<Value> newConstOps;
Value mulOffset;
for (auto offsets : *sgOffsets) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 5f990a49f1298..676c96db69236 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -459,8 +459,8 @@ gpu.module @test_distribution {
gpu.return
}
- // CHECK-LABEL: non_splat_constant
- gpu.func @non_splat_constant() {
+ // CHECK-LABEL: non_splat_constant_2D
+ gpu.func @non_splat_constant_2D() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: affine.apply #map4()[%[[SGID]]]
More information about the Mlir-commits
mailing list