[Mlir-commits] [mlir] [AMDGPU] Implement gpu.subgroup_reduce with DPP intrinsics on AMD GPUs (PR #133204)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 28 15:46:43 PDT 2025
================
@@ -0,0 +1,203 @@
+//===- GPUToAMDGPU.cpp - GPU to AMDGPU dialect conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+#include <cstdint>
+
+#include "../LLVMCommon/MemRefDescriptor.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Cluster configuration of a gpu.subgroup_reduce op after validation by
+/// getAndValidateClusterInfo.
+struct ClusterInfo {
+  /// The op's cluster stride; checked to be smaller than `subgroupSize`.
+  unsigned clusterStride;
+  /// The op's cluster size, or `subgroupSize` when the op is not clustered.
+  unsigned clusterSize;
+  /// Subgroup size the lowering was configured with.
+  unsigned subgroupSize;
+};
+
+/// Extracts the cluster configuration of `op` and checks it against the
+/// configured `subgroupSize`.
+///
+/// Emits an op error and returns failure when the requested cluster size is
+/// larger than the subgroup, or when the cluster stride is not strictly smaller
+/// than the subgroup. Power-of-two requirements are only asserted, because the
+/// op verifier is expected to enforce them. An unclustered op is treated as one
+/// cluster covering the whole subgroup.
+static FailureOr<ClusterInfo>
+getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
+  assert(llvm::isPowerOf2_32(subgroupSize));
+
+  const std::optional<uint32_t> requested = op.getClusterSize();
+  // Verifier should've caught this.
+  assert(!requested || llvm::isPowerOf2_32(*requested));
+  if (requested && *requested > subgroupSize) {
+    return op.emitOpError() << "cluster size " << *requested
+                            << " is greater than subgroup size "
+                            << subgroupSize;
+  }
+  const unsigned size = requested ? *requested : subgroupSize;
+
+  const auto stride = op.getClusterStride();
+  // Verifier should've caught this.
+  assert(llvm::isPowerOf2_32(stride));
+  if (stride >= subgroupSize) {
+    return op.emitOpError() << "cluster stride " << stride
+                            << " is not less than subgroup size "
+                            << subgroupSize;
+  }
+
+  return ClusterInfo{stride, size, subgroupSize};
+}
+
+/// Emits a DPP-based reduction of `input` over clusters of `ci.clusterSize`
+/// consecutive lanes, combining values with the reduction operator `mode`.
+///
+/// The 2-, 4-, 8- and 16-lane steps exchange values within a DPP row:
+/// quad_perm butterflies, then row_half_mirror and row_mirror. Every lane has
+/// a valid source in each step, so after them every lane of a cluster of up to
+/// 16 lanes holds the cluster result.
+///
+/// The 32- and 64-lane steps use row_bcast_15 / row_bcast_31 with row masks,
+/// which only leave the combined value in the upper row(s). When the cluster is
+/// the whole subgroup, the value is read from the last lane and made uniform
+/// with rocdl.readlane.
+///
+/// NOTE(review):
+/// - `ci.clusterStride` is not used; the lane patterns assume stride 1.
+/// - Row broadcasts are presumably only available on GFX9-class targets;
+///   confirm the chipset.
+/// - A 32-lane cluster inside a 64-lane subgroup does not receive the result
+///   in rows 0 and 2. TODO: handle that case.
+///
+/// \returns the reduced value, with the same type as `input`.
+Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
+                                 gpu::AllReduceOperation mode,
+                                 const ClusterInfo &ci) {
+  constexpr int allRows = 0xf;
+  constexpr int allBanks = 0xf;
+  const vector::CombiningKind kind = gpu::convertReductionKind(mode);
+
+  Value result = input;
+  // Moves `result` between lanes with one DPP permutation and folds the moved
+  // value back into `result` with the reduction operator.
+  auto combineWithDPP = [&](amdgpu::DPPPerm perm, Attribute permArg,
+                            int rowMask, int bankMask, bool boundCtrl) {
+    Value moved =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result, perm,
+                                permArg, rowMask, bankMask, boundCtrl);
+    result = vector::makeArithReduction(b, loc, kind, result, moved);
+  };
+
+  if (ci.clusterSize >= 2) {
+    // Lane N <-> N ^ 1 within each quad. Unlike row_shr:1, every lane has a
+    // real source, so no lane combines its value with itself.
+    combineWithDPP(amdgpu::DPPPerm::quad_perm, b.getI32ArrayAttr({1, 0, 3, 2}),
+                   allRows, allBanks, /*boundCtrl=*/false);
+  }
+
+  if (ci.clusterSize >= 4) {
+    // Lane N <-> N ^ 2 within each quad.
+    combineWithDPP(amdgpu::DPPPerm::quad_perm, b.getI32ArrayAttr({2, 3, 0, 1}),
+                   allRows, allBanks, /*boundCtrl=*/false);
+  }
+
+  if (ci.clusterSize >= 8) {
+    // Lane N <-> 7 - N within each half-row, pairing the two quads.
+    combineWithDPP(amdgpu::DPPPerm::row_half_mirror, b.getUnitAttr(), allRows,
+                   allBanks, /*boundCtrl=*/false);
+  }
+
+  if (ci.clusterSize >= 16) {
+    // Lane N <-> 15 - N within each row, pairing the two half-rows.
+    combineWithDPP(amdgpu::DPPPerm::row_mirror, b.getUnitAttr(), allRows,
+                   allBanks, /*boundCtrl=*/false);
+  }
+
+  if (ci.clusterSize >= 32) {
+    // Broadcast lane 15 of each row into the next row. Row mask 0xa restricts
+    // the update to rows 1 and 3, which then hold the totals of rows {0,1} and
+    // {2,3}.
+    combineWithDPP(amdgpu::DPPPerm::row_bcast_15, b.getUnitAttr(), 0xa,
+                   allBanks, /*boundCtrl=*/false);
+  }
+
+  if (ci.clusterSize == 64) {
+    // Broadcast lane 31 (row 1, holding rows {0,1}) into rows 2 and 3. Row
+    // mask 0xc restricts the update to those rows, so row 3 ends up with the
+    // total of the whole wave.
+    combineWithDPP(amdgpu::DPPPerm::row_bcast_31, b.getUnitAttr(), 0xc,
+                   allBanks, /*boundCtrl=*/false);
+  }
+
+  // After the row broadcasts only the last row holds the cluster total. When
+  // the cluster is the entire subgroup, read the last lane so that every lane
+  // sees the result.
+  const bool usedRowBroadcast = ci.clusterSize == 32 || ci.clusterSize == 64;
+  if (usedRowBroadcast && ci.clusterSize == ci.subgroupSize) {
+    Value lastLane = b.create<arith::ConstantOp>(
+        loc, b.getI32IntegerAttr(ci.subgroupSize - 1));
+    result =
+        b.create<ROCDL::ReadlaneOp>(loc, result.getType(), result, lastLane);
+  }
+
+  assert(result.getType() == input.getType());
+  return result;
+}
+
+/// Rewrites a scalar gpu.subgroup_reduce op into a sequence of AMDGPU DPP lane
+/// exchanges built by createSubgroupDPPReduction.
+///
+/// The pattern matches either only clustered or only non-clustered ops,
+/// depending on `matchClustered`. It rejects:
+/// - non-scalar values;
+/// - cluster strides other than 1;
+/// - clusters of 32+ lanes that do not cover the whole subgroup.
+/// The DPP expansion does not handle any of these cases.
+struct ScalarSubgroupReduceToShuffles final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+                                 bool matchClustered, PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        matchClustered(matchClustered) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+
+    FailureOr<ClusterInfo> ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
+
+    // The DPP exchanges operate on single scalar values per lane.
+    Type valueTy = op.getType();
+    if (!valueTy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "value type is not a compatible scalar");
+
+    // The DPP lane patterns combine consecutive lanes, so any other stride
+    // would silently reduce the wrong lanes.
+    if (ci->clusterStride != 1)
+      return rewriter.notifyMatchFailure(
+          op, "cluster stride other than 1 is not supported");
+
+    // The row-broadcast steps used for clusters of 32+ lanes only leave the
+    // result in some rows, so partial-subgroup clusters of that size would
+    // produce wrong values in the remaining rows.
+    if (ci->clusterSize >= 32 && ci->clusterSize != ci->subgroupSize)
+      return rewriter.notifyMatchFailure(
+          op, "clusters of 32 or more lanes must span the whole subgroup");
+
+    Location loc = op.getLoc();
+    rewriter.replaceOp(op, createSubgroupDPPReduction(
+                               rewriter, loc, op.getValue(), op.getOp(), *ci));
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  bool matchClustered = false;
+};
+
+struct ConvertGPUToAMDGPUPass
+ : public impl::ConvertGPUToAMDGPUPassBase<ConvertGPUToAMDGPUPass> {
+ using Base::Base;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
----------------
Muzammiluddin-Syed-ECE wrote:
moved it back
https://github.com/llvm/llvm-project/pull/133204
More information about the Mlir-commits mailing list