[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)

Artem Kroviakov llvmlistbot at llvm.org
Tue Dec 16 08:39:47 PST 2025


================
@@ -1161,64 +1161,338 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+/// This function converts multi-dimensional subgroup indices into a single
+/// linear offset. It's used to calculate memory offsets in SLM for
+/// cross-subgroup reduction coordination.
+///
+/// Parameters:
+/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
+/// - dims: Which dimensions to include in linearization, listed from the
+///   fastest-varying to the slowest-varying (e.g., [0, 2] for x and z dims)
+/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
+///   4x8x2 subgroups)
+///
+/// The first dimension listed in `dims` varies fastest:
+///    offset = sum_i(sgIds[dims[i]] * stride_i)
+///    where stride_0 = 1 and stride_i = stride_{i-1} * sgLayout[dims[i-1]],
+///    i.e. stride_i is the product of the sgLayout sizes of the dimensions
+///    listed *before* dims[i] in `dims`. Dimensions not listed in `dims` do
+///    not contribute to the offset.
+///
+/// Example:
+/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
+/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
+/// - Calculation:
+///   * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
+///   * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
+///   * linearizedOffset = 1 + 4 = 5
+///
+/// This gives us a unique linear index for each combination of subgroup
+/// positions in the specified dimensions, which is used for SLM row/column
+/// addressing.
+static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
+                                      Location loc, ArrayRef<Value> sgIds,
+                                      ArrayRef<int64_t> dims,
+                                      ArrayRef<int64_t> sgLayout) {
+  Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  // Running stride: product of the layout sizes of the dims already visited.
+  int64_t stride = 1;
+
+  for (int64_t dim : dims) {
+    assert(dim >= 0 && static_cast<size_t>(dim) < sgIds.size() &&
+           static_cast<size_t>(dim) < sgLayout.size() &&
+           "linearization dimension out of range");
+    Value dimVal = sgIds[dim];
+    Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+    Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+    linearizedOffset =
+        arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
+    stride *= sgLayout[dim];
+  }
+
+  return linearizedOffset;
+}
+
+/// Combines `lhs` and `rhs` with the arith operation that implements the
+/// given vector reduction `kind` and returns the combined value. For example,
+/// ADD maps to arith.addf or arith.addi depending on the element type, and
+/// MAXSI maps to arith.maxsi.
+///
+/// This delegates to vector::makeArithReduction instead of keeping a
+/// hand-written copy of the kind -> arith op table. That keeps the mapping,
+/// including the float/integer dispatch for ADD and MUL, in sync with the
+/// vector dialect. `lhs` is passed as the first operand and `rhs` as the
+/// accumulator, which preserves the operand order of the combining op.
+static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
+                             vector::CombiningKind kind, Value lhs, Value rhs) {
+  return vector::makeArithReduction(rewriter, loc, kind, lhs, rhs);
+}
+
+/// This pattern transforms vector.multi_dim_reduction operations from
+/// workgroup-level to subgroup-level execution with support for multiple
+/// reduction dimensions.
+///
+/// Steps include:
+/// 1. LOCAL REDUCTION :
+///    - Each subgroup performs local reduction on its data slice
+///    - Uses ZERO accumulator to avoid double-counting during cross-subgroup
----------------
akroviakov wrote:

Isn't this risky to use `0` for min/max reductions?

https://github.com/llvm/llvm-project/pull/170936


More information about the Mlir-commits mailing list