[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)

Artem Kroviakov llvmlistbot at llvm.org
Tue Dec 16 08:32:49 PST 2025


================
@@ -1161,64 +1161,338 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+/// This function converts multi-dimensional subgroup indices into a single
+/// linear offset. It's used to calculate memory offsets in SLM for
+/// cross-subgroup reduction coordination.
+///
+/// Parameters:
+/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
+/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
+///   z dims)
+/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
+///   4x8x2 subgroups)
+///
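+/// It linearizes in the order given by `dims`, with the first listed
+/// dimension varying fastest:
+///    offset = sum(sgIds[dim] * stride[dim])
+///    where stride[dim] = product of the sgLayout sizes of all dimensions
+///    listed before 'dim' in `dims` (1 for the first listed dimension)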
+///
+/// Example:
+/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
+/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
+/// - Calculation:
+///   * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
+///   * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
+///   * linearizedOffset = 1 + 4 = 5
+///
+/// This gives us a unique linear index for each combination of subgroup
+/// positions in the specified dimensions, which is used for SLM row/column
+/// addressing.
+static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
+                                      Location loc, ArrayRef<Value> sgIds,
+                                      ArrayRef<int64_t> dims,
+                                      ArrayRef<int64_t> sgLayout) {
+  Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  int64_t stride = 1;
+
+  for (int64_t dim : dims) {
+    Value dimVal = sgIds[dim];
+    Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+    Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+    linearizedOffset =
+        arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
+    stride *= sgLayout[dim];
+  }
+
+  return linearizedOffset;
+}
+
+// Helper function to create the appropriate binary operation based on the
+// reduction kind.
+static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
----------------
akroviakov wrote:

Isn't it the same as `vector::makeArithReduction`?

https://github.com/llvm/llvm-project/pull/170936
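To sanity-check the worked example in the `linearizeSubgroupIndices` doc comment, its loop can be mirrored on plain integers. The sketch below is a hypothetical standalone analogue and is not part of the PR. It drops the MLIR `Value` and `ConversionPatternRewriter` plumbing and keeps only the stride arithmetic. The first dimension listed in `dims` gets stride 1, and each later one is scaled by the layout sizes of the dimensions listed before it.

```cpp
#include <cassert>  // for the usage checks below
#include <cstdint>
#include <vector>

// Hypothetical plain-integer analogue of linearizeSubgroupIndices from the
// diff: the same loop, but on int64_t values instead of MLIR index Values.
int64_t linearizeSubgroupIndices(const std::vector<int64_t> &sgIds,
                                 const std::vector<int64_t> &dims,
                                 const std::vector<int64_t> &sgLayout) {
  int64_t linearizedOffset = 0;
  int64_t stride = 1;  // the first dimension listed in `dims` varies fastest
  for (int64_t dim : dims) {
    linearizedOffset += sgIds[dim] * stride;
    // The next listed dimension steps over every position of this one.
    stride *= sgLayout[dim];
  }
  return linearizedOffset;
}
```

For example, `linearizeSubgroupIndices({1, 3, 1}, {0, 2}, {4, 8, 2})` evaluates to 1 * 1 + 1 * 4 = 5, which matches the doc comment. Swapping the order to `dims = {2, 0}` gives 1 * 1 + 1 * 2 = 3, so the order of `dims` matters. Each `sgIds[dim]` is below `sgLayout[dim]`, so the result is a mixed-radix number, and distinct index combinations produce distinct offsets.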

