[Mlir-commits] [mlir] [MLIR][XeGPU] Support leading unit dim for reduction in sg to wi pass (PR #185110)
Jianhui Li
llvmlistbot at llvm.org
Fri Mar 6 13:59:54 PST 2026
================
@@ -1321,17 +1321,29 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
unsigned operandIdx = yieldOperand->getOperandNumber();
VectorType sourceType = reductionOp.getSourceVectorType();
- // Only 2D vectors are supported.
- if (sourceType.getRank() != 2)
+ int64_t sourceRank = sourceType.getRank();
+ // Need at least a 2D source vector.
+ if (sourceRank < 2)
return rewriter.notifyMatchFailure(warpOp,
- "Only 2D reductions are supported.");
+ "Only 2D+ reductions are supported.");
+ // Leading dimensions (first rank-2) must be unit (size 1).
+ for (int64_t i = 0; i < sourceRank - 2; ++i) {
+ if (sourceType.getShape()[i] != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Only unit dimensions allowed for the leading dimensions.");
+ }
+ // Effective dimension indices (last 2 dims of the source).
+ int64_t dim0Idx = sourceRank - 2;
----------------
Jianhui-Li wrote:
I think dim0 and dim1 are a bit confusing.
How about: dim0Idx -> rowIdx, dim1Idx -> columnIdx?
The same naming issue exists inside lowerToVectorReductions.
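Here is a minimal sketch of the precondition check from the diff above, written with the proposed rowIdx/columnIdx names. It is plain C++ over a `std::vector<int64_t>` shape rather than MLIR's `VectorType`. The helper `getEffective2DDims` and the struct `ReductionDims` are hypothetical and exist only for illustration; they are not part of the PR or of MLIR.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical result type: indices of the effective 2D (row, column) slice.
struct ReductionDims {
  int64_t rowIdx;     // index of the second-to-last dimension
  int64_t columnIdx;  // index of the last dimension
};

// Hypothetical helper mirroring the checks in the diff: rejects sources with
// rank < 2 or any non-unit leading dimension. Otherwise it returns the indices
// of the last two dimensions.
std::optional<ReductionDims>
getEffective2DDims(const std::vector<int64_t> &shape) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  if (rank < 2)
    return std::nullopt;  // need at least a 2D source vector
  for (int64_t i = 0; i < rank - 2; ++i)
    if (shape[i] != 1)
      return std::nullopt;  // leading dimensions must be unit
  return ReductionDims{rank - 2, rank - 1};
}
```

For example, a `1x1x16x32` source yields rowIdx 2 and columnIdx 3. A `2x16x32` source is rejected because its leading dimension is not 1. With these names, the role of each index is clear at the use site, whereas with dim0Idx/dim1Idx the reader has to recall which one is the row.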
https://github.com/llvm/llvm-project/pull/185110