[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute non-splat constant from wg to sg (PR #161416)
Igor Zamyatin
llvmlistbot at llvm.org
Mon Oct 6 14:40:12 PDT 2025
================
@@ -733,22 +733,172 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
- // Current limitation: constant of vector with single value.
- // TODO: support more complex cases, e.g., vector with multiple values.
- Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
auto newType = VectorType::get(sgShape, vecType.getElementType());
- auto sgAttr = DenseElementsAttr::get(newType, singleVal);
- auto cstOp =
- arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
- layout.dropSgLayoutAndData());
- SmallVector<Value> newConsts(count, cstOp);
+ Location loc = op.getLoc();
+ auto eltType = vecType.getElementType();
- rewriter.replaceOpWithMultiple(op, {newConsts});
- return success();
+ auto setLayoutIfNeeded = [&](Value val) {
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
+ }
+ };
+
+ if (vecAttr.isSplat()) {
+ // Splat: single value for all subgroups
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+ setLayoutIfNeeded(cstOp->getResult(0));
+ rewriter.replaceOp(op, cstOp);
+ return success();
+ } else if (sgShape == wgShape) { // if the entire vector is shared by all
+ // subgroups, don't distribute
+ auto newConstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+ setLayoutIfNeeded(newConstOp->getResult(0));
+ rewriter.replaceOp(op, newConstOp);
+ return success();
+ } else {
+ // Non-splat constant
+ // Only supports 1D & 2D
+ // TODO: support other cases that require SLM access
+ if (!eltType.isIndex())
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported element type for non-splat constant op.");
+
+ if (wgShape.size() > 2)
+ return rewriter.notifyMatchFailure(
+ op, "Only 1D & 2D vector constant supported");
+
+ SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+ int64_t stride = 0;
+ int64_t rowStride = 0, colStride = 0;
+ if (wgShape.size() == 1) {
+ // 1D case: single stride
+ if (values.size() > 1) {
+ stride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (size_t i = 2; i < values.size(); ++i) {
+ int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+ cast<IntegerAttr>(values[i - 1]).getInt();
+ if (diff != stride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant stride in non-splat constant op.");
+ }
+ }
+ } else if (wgShape.size() == 2) {
+ // 2D case: row stride and column stride
+ int64_t rows = wgShape[0], cols = wgShape[1];
+ // Compute col stride (stride between elements in a column)
+ if (cols > 1) {
+ colStride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 1; c < cols; ++c) {
+ int64_t idx = r * cols + c;
+ int64_t prevIdx = r * cols + (c - 1);
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != colStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant column stride in 2D constant op.");
+ }
+ }
+ }
+ // Compute row stride (stride between elements in a row)
+ if (rows > 1) {
+ rowStride = cast<IntegerAttr>(values[cols]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ for (int64_t c = 0; c < cols; ++c) {
+ for (int64_t r = 1; r < rows; ++r) {
+ int64_t idx = r * cols + c;
+ int64_t prevIdx = (r - 1) * cols + c;
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != rowStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant row stride in 2D constant op.");
+ }
+ }
+ }
+ }
+
+ // Determine the shape of the base tile for each subgroup.
+ SmallVector<int64_t> baseTileShape;
----------------
Garra1980 wrote:
Can you use sgShape directly here instead of introducing the new baseTileShape variable?
https://github.com/llvm/llvm-project/pull/161416
More information about the Mlir-commits
mailing list