[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute non-splat constant from wg to sg (PR #161416)
Igor Zamyatin
llvmlistbot at llvm.org
Mon Oct 6 14:30:44 PDT 2025
================
@@ -733,22 +733,116 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
- // Current limitation: constant of vector with single value.
- // TODO: support more complex cases, e.g., vector with multiple values.
- Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
auto newType = VectorType::get(sgShape, vecType.getElementType());
- auto sgAttr = DenseElementsAttr::get(newType, singleVal);
- auto cstOp =
- arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
- layout.dropSgLayoutAndData());
- SmallVector<Value> newConsts(count, cstOp);
+ Location loc = op.getLoc();
+ auto eltType = vecType.getElementType();
- rewriter.replaceOpWithMultiple(op, {newConsts});
- return success();
+ auto setLayoutIfNeeded = [&](Value val) {
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
+ }
+ };
+
+ if (vecAttr.isSplat()) {
+ // Splat: single value for all subgroups
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+ setLayoutIfNeeded(cstOp->getResult(0));
+ rewriter.replaceOp(op, cstOp);
+ return success();
+ } else if (sgShape == wgShape) { // if the entire vector is shared by all
----------------
Garra1980 wrote:
is there a test for this branch?
https://github.com/llvm/llvm-project/pull/161416
More information about the Mlir-commits mailing list