[Mlir-commits] [mlir] [mlir][AMDGPU] Implement gpu.subgroup_reduce with DPP intrinsics on AMD GPUs (PR #133204)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Apr 23 08:01:35 PDT 2025
================
@@ -362,6 +366,163 @@ struct VectorSubgroupReduceToShuffles final
unsigned shuffleBitwidth = 0;
bool matchClustered = false;
};
+
+/// Lowers a scalar `gpu.subgroup_reduce` to a log-step reduction built from
+/// `amdgpu.dpp` ops (and, on gfx10/11/12, `rocdl.permlanex16` /
+/// `rocdl.readlane`). Each stage combines a lane's value with that of a lane
+/// a fixed distance away, doubling the reduced width until `ci.clusterSize`
+/// contiguous lanes are covered. Returns the reduced value (same type as
+/// `input`), or a match failure when the chipset has no supported lowering.
+FailureOr<Value>
+createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
+                           Value input, gpu::AllReduceOperation mode,
+                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
+  Location loc = op.getLoc();
+  Value dpp;
+  Value res = input;
+  // DPP row/bank masks: bit i set enables row/bank i for writing; 0xf enables
+  // all of them.
+  constexpr int allRows = 0xf;
+  constexpr int allBanks = 0xf;
+  // bound_ctrl=1: lanes whose source is invalid/disabled read 0 instead of
+  // leaving the destination unchanged (AMD DPP semantics — see ISA docs).
+  const bool boundCtrl = true;
+  if (ci.clusterSize >= 2) {
+    // Perform reduction between all lanes N <-> N+1.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+
+  if (ci.clusterSize >= 4) {
+    // Perform reduction between all lanes N <-> N+2.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 8) {
+    // Perform reduction between all lanes N <-> 7-N,
+    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 16) {
+    // Perform reduction between all lanes N <-> 15-N,
+    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 32) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast last value from each row to the next row.
+      // NOTE(review): row mask 0xa enables writes to rows 1 and 3 only, i.e.
+      // it keeps rows 0 and 2 untouched — the row mask *protects* rows 0 and
+      // 2 here (cf. the consistent 0xc/rows-2-and-3 usage in the >=64 case
+      // below); confirm against the DPP row_mask write-enable semantics.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
+          rewriter.getUnitAttr(), 0xa, allBanks,
+          /*bound_ctrl*/ false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+    } else if (chipset.majorVersion <= 12) {
+      // gfx10+ lacks the row_bcast DPP modes; use a lane permute to cross
+      // rows instead (row 1 <-> row 0, row 3 <-> row 2). The all-ones
+      // selects and /*fi=*/true make the permute cover every lane.
+      Value uint32Max = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
+      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
+                                                  uint32Max, uint32Max,
+                                                  /*fi=*/true,
+                                                  /*bound_ctrl=*/false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+      if (ci.subgroupSize == 32) {
+        // The 32-lane reduction is complete; read lane 0 to make the result
+        // uniform across the subgroup.
+        Value lane0 = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+        res =
+            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+      }
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+  }
+  if (ci.clusterSize >= 64) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast 31st lane value to rows 2 and 3.
+      // Row mask 0xc enables writes to rows 2 and 3 only, keeping rows 0 and
+      // 1 untouched. The combine with `res` happens below, after the if/else.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
+          rewriter.getUnitAttr(), 0xc, allBanks,
+          /*bound_ctrl*/ false);
+
+    } else if (chipset.majorVersion <= 12) {
+      // Assume the reduction across each 32-lane half has been done above.
+      // Perform the final step by reading the partial results held in lane 0
+      // and lane 32; they are combined (per `mode`, not necessarily a sum)
+      // by the shared makeArithReduction below.
+      Value lane0 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+      Value lane32 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
+      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+    // Combine the two partial results produced by either branch above.
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  // The lowering must preserve the operand type.
+  assert(res.getType() == input.getType());
+  return res;
+}
+
+/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
+/// ops over scalar types. Assumes that the subgroup has
+/// `subgroupSize` lanes. Applicable only to AMD GPUs.
+struct ScalarSubgroupReduceToDPP final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+  /// `subgroupSize` is the number of lanes the target subgroup is assumed to
+  /// have; `matchClustered` selects whether this pattern matches clustered or
+  /// non-clustered `gpu.subgroup_reduce` ops; `chipset` identifies the AMD
+  /// GPU generation the lowering targets.
+  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
+                            bool matchClustered, amdgpu::Chipset chipset,
+                            PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        matchClustered(matchClustered), chipset(chipset) {}
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getClusterSize().has_value() != matchClustered) {
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("op is {0}clustered but pattern is configured to "
+ "only match {1}clustered ops",
+ matchClustered ? "non-" : "",
+ matchClustered ? "" : "non-"));
+ }
+ auto ci = getAndValidateClusterInfo(op, subgroupSize);
+ if (failed(ci))
+ return failure();
+
+ if (ci->clusterStride != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Supgroup reductions using DPP are currently only available for "
+ "clusters of contiguous lanes.");
+
+ Type valueTy = op.getType();
+ if (!valueTy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(
+ op, "value type is not a compatible scalar");
----------------
kuhar wrote:
nit: Use consistent capitalization across these messages
https://github.com/llvm/llvm-project/pull/133204
More information about the Mlir-commits
mailing list