[Mlir-commits] [mlir] [AMDGPU] Implement gpu.subgroup_reduce with DPP intrinsics on AMD GPUs (PR #133204)

Jakub Kuderski llvmlistbot at llvm.org
Thu Apr 3 18:22:13 PDT 2025


================
@@ -362,6 +364,106 @@ struct VectorSubgroupReduceToShuffles final
   unsigned shuffleBitwidth = 0;
   bool matchClustered = false;
 };
+
+Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
+                                 gpu::AllReduceOperation mode,
+                                 const ClusterInfo &ci) {
+  Value result = input;
+  if (ci.clusterSize >= 2) {
+    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 1);
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shl, permArg);
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
+  if (ci.clusterSize >= 4) {
+    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 2);
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shl, permArg);
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
+  if (ci.clusterSize >= 8) {
+    Value dppResult = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
+        b.getUnitAttr());
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
+  if (ci.clusterSize >= 16) {
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_mirror, b.getUnitAttr());
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  }
+
+  const int allRows = 0xf;
+  const int allBanks = 0xf;
+  auto int32Type = IntegerType::get(b.getContext(), 32);
----------------
kuhar wrote:

`b.getI32Type()`

https://github.com/llvm/llvm-project/pull/133204


More information about the Mlir-commits mailing list