[Mlir-commits] [mlir] [MLIR][XeGPU] Support leading unit dim for reduction in sg to wi pass (PR #185110)

Jianhui Li llvmlistbot at llvm.org
Fri Mar 6 13:59:54 PST 2026


================
@@ -1321,17 +1321,29 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
     unsigned operandIdx = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
-    // Only 2D vectors are supported.
-    if (sourceType.getRank() != 2)
+    int64_t sourceRank = sourceType.getRank();
+    // Need at least a 2D source vector.
+    if (sourceRank < 2)
       return rewriter.notifyMatchFailure(warpOp,
-                                         "Only 2D reductions are supported.");
+                                         "Only 2D+ reductions are supported.");
+    // Leading dimensions (first rank-2) must be unit (size 1).
+    for (int64_t i = 0; i < sourceRank - 2; ++i) {
+      if (sourceType.getShape()[i] != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Only unit dimensions allowed for the leading dimensions.");
+    }
+    // Effective dimension indices (last 2 dims of the source).
+    int64_t dim0Idx = sourceRank - 2;
----------------
Jianhui-Li wrote:

I think dim0 and dim1 are a bit confusing.
How about: dim0Idx -> rowIdx, dim1Idx -> columnIdx?

The same naming issue exists inside lowerToVectorReductions.
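For illustration only, here is a minimal standalone C++ sketch of the shape check the diff introduces, written with the rowIdx/columnIdx names suggested above. It assumes the source shape is available as a plain std::vector<int64_t> instead of MLIR's VectorType. The struct EffectiveReductionDims and the function getEffectiveReductionDims are hypothetical names, not part of MLIR or of this PR. The check requires rank >= 2 and unit leading dimensions, then maps the reduction onto the last two dims.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical result type (not an MLIR API): indices of the two
// effective dimensions of a source vector with unit leading dims.
struct EffectiveReductionDims {
  int64_t rowIdx;    // second-to-last dimension of the source
  int64_t columnIdx; // last dimension of the source
};

// Hypothetical helper mirroring the checks in the diff:
//  - the source must be at least 2D;
//  - every leading dimension (the first rank - 2) must have size 1.
// Returns std::nullopt when the shape is not supported.
std::optional<EffectiveReductionDims>
getEffectiveReductionDims(const std::vector<int64_t> &shape) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  if (rank < 2)
    return std::nullopt; // need at least a 2D source vector
  for (int64_t i = 0; i < rank - 2; ++i)
    if (shape[i] != 1)
      return std::nullopt; // non-unit leading dimension
  return EffectiveReductionDims{rank - 2, rank - 1};
}
```

With this sketch, a shape such as {1, 1, 16, 32} yields rowIdx == 2 and columnIdx == 3, while {2, 16, 32} is rejected because its leading dimension is not 1. Naming the fields by role (row/column) rather than by position (dim0/dim1) avoids confusion once the indices are offset by the leading unit dimensions.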

https://github.com/llvm/llvm-project/pull/185110

