[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)
Artem Kroviakov
llvmlistbot at llvm.org
Tue Dec 16 08:39:47 PST 2025
================
@@ -1161,64 +1161,338 @@ struct WgToSgVectorShapeCastOp
}
};
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+/// This function converts multi-dimensional subgroup indices into a single
+/// linear offset. It's used to calculate memory offsets in SLM for
+/// cross-subgroup reduction coordination.
+///
+/// Parameters:
+/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
+/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
+/// z dims)
+/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
+/// 4x8x2 subgroups)
+///
+/// The first entry of 'dims' varies fastest (it gets stride 1):
+/// offset = sum(sgIds[dim] * stride[dim])
+/// where stride[dim] = product of the sgLayout sizes of the entries that
+/// precede 'dim' in 'dims'
+///
+/// Example:
+/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
+/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
+/// - Calculation:
+/// * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
+/// * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
+/// * linearizedOffset = 1 + 4 = 5
+///
+/// This gives us a unique linear index for each combination of subgroup
+/// positions in the specified dimensions, which is used for SLM row/column
+/// addressing.
+static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
+ Location loc, ArrayRef<Value> sgIds,
+ ArrayRef<int64_t> dims,
+ ArrayRef<int64_t> sgLayout) {
+ Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ int64_t stride = 1; // stride applied to the first entry of dims
+
+ for (int64_t dim : dims) {
+ Value dimVal = sgIds[dim];
+ Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+ Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+ linearizedOffset =
+ arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
+ stride *= sgLayout[dim]; // the next entry of dims steps over this extent
+ }
+
+ return linearizedOffset;
+}
+
+// Emits the arith op that combines lhs and rhs for the given reduction kind;
+// ADD and MUL pick the float or integer op from the element type of lhs.
+static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
+ vector::CombiningKind kind, Value lhs, Value rhs) {
+ Type elemType = getElementTypeOrSelf(lhs.getType());
+ bool isFloat = isa<FloatType>(elemType);
+
+ switch (kind) {
+ case vector::CombiningKind::ADD:
+ return isFloat ? arith::AddFOp::create(rewriter, loc, lhs, rhs).getResult()
+ : arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MUL:
+ return isFloat ? arith::MulFOp::create(rewriter, loc, lhs, rhs).getResult()
+ : arith::MulIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINSI:
+ return arith::MinSIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINUI:
+ return arith::MinUIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXSI:
+ return arith::MaxSIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXUI:
+ return arith::MaxUIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::AND:
+ return arith::AndIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::OR:
+ return arith::OrIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::XOR:
+ return arith::XOrIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINNUMF:
+ return arith::MinNumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXNUMF:
+ return arith::MaxNumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINIMUMF:
+ return arith::MinimumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXIMUMF:
+ return arith::MaximumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ }
+ llvm_unreachable("unsupported OpKind"); // kind matched no case above
+}
+
+/// This pattern transforms vector.multi_reduction operations from
+/// workgroup-level to subgroup-level execution with support for multiple
+/// reduction dimensions.
+///
+/// Steps include:
+/// 1. LOCAL REDUCTION :
+/// - Each subgroup performs local reduction on its data slice
+/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
----------------
akroviakov wrote:
Isn't it risky to use `0` as the accumulator for min/max reductions?
https://github.com/llvm/llvm-project/pull/170936
More information about the Mlir-commits
mailing list