[Mlir-commits] [mlir] [AMDGPU] Implement gpu.subgroup_reduce with DPP intrinsics on AMD GPUs (PR #133204)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Apr 14 09:44:40 PDT 2025
================
@@ -362,6 +365,119 @@ struct VectorSubgroupReduceToShuffles final
unsigned shuffleBitwidth = 0;
bool matchClustered = false;
};
+
+Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
+ gpu::AllReduceOperation mode,
+ const ClusterInfo &ci,
+ amdgpu::Chipset chipset) {
+ Value result = input;
+ Value dppResult;
+ const int allRows = 0xf;
+ const int allBanks = 0xf;
+ const bool boundCtrl = true;
+ if (ci.clusterSize >= 2) {
+ auto permArg = b.getI32ArrayAttr({1, 0, 3, 2});
+ dppResult =
+ b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+ amdgpu::DPPPerm::quad_perm, permArg, allRows, allBanks, boundCtrl);
+ result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+ result, dppResult);
+ }
+
+ if (ci.clusterSize >= 4) {
+ auto permArg = b.getI32ArrayAttr({2, 3, 0, 1});
+ dppResult =
+ b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+ amdgpu::DPPPerm::quad_perm, permArg, allRows, allBanks, boundCtrl);
+ result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+ result, dppResult);
+ }
+
+ if (ci.clusterSize >= 8) {
+ dppResult = b.create<amdgpu::DPPOp>(
+ loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
+ b.getUnitAttr(), allRows, allBanks, boundCtrl);
+ result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+ result, dppResult);
+ }
+
+ if (ci.clusterSize >= 16) {
+ dppResult = b.create<amdgpu::DPPOp>(
+ loc, result.getType(), result, result, amdgpu::DPPPerm::row_mirror,
+ b.getUnitAttr(), allRows, allBanks, boundCtrl);
+ result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+ result, dppResult);
+ }
+
+ Value lane31 = b.create<LLVM::ConstantOp>(loc, b.getI32Type(), 31);
+ Value lane63 = b.create<LLVM::ConstantOp>(loc, b.getI32Type(), 63);
+ if (ci.clusterSize >= 32) {
+ if (chipset.majorVersion <= 9) {
+ auto permArg = b.getI32IntegerAttr(15);
+ dppResult = b.create<amdgpu::DPPOp>(
+ loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
+ b.getUnitAttr(), 0xa, allBanks, false);
+ } else if (chipset.majorVersion == 10) {
----------------
krzysz00 wrote:
This should be `>= 10`: gfx11 and gfx12 also have this instruction, since its predicate is `IsGfx10Plus`.
https://github.com/llvm/llvm-project/pull/133204
More information about the Mlir-commits mailing list