[Mlir-commits] [llvm] [mlir] [mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR. (PR #89233)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 17 06:25:26 PDT 2024
https://github.com/stefankoncarevic updated https://github.com/llvm/llvm-project/pull/89233
>From ce8792aad6d86ad33a327fdd002d952797ea034c Mon Sep 17 00:00:00 2001
From: stefankoncarevic <skoncare at amd.com>
Date: Thu, 18 Apr 2024 13:26:20 +0000
Subject: [PATCH 1/2] [mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR.
Defined an AMDGPU DPP operation in MLIR to represent DPP semantics. Introduced
a new enumeration attribute for the different permutations and allowed
different argument types. Implemented constant attribute handling for the
ROCDL::DPPUpdateOp operation. The operation now accepts constant attributes
for dppCtrl, rowMask, bankMask, and boundCtrl, and passes them to the
corresponding LLVM intrinsic.
---
llvm/include/llvm/IR/IntrinsicsAMDGPU.td | 6 +
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 55 +++++++
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 24 +++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 145 ++++++++++++++++-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 62 ++++++++
mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir | 148 ++++++++++++++++++
6 files changed, 438 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index be8048ca2459c..f58aea9b351bc 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2411,6 +2411,8 @@ def int_amdgcn_buffer_wbinvl1_vol :
// VI Intrinsics
//===----------------------------------------------------------------------===//
+// The llvm.amdgcn.mov.dpp.i32 intrinsic represents the AMDGPU mov.dpp
+// operation. It is being deprecated; use llvm.amdgcn.update.dpp.i32 instead.
// llvm.amdgcn.mov.dpp.i32 <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
def int_amdgcn_mov_dpp :
Intrinsic<[llvm_anyint_ty],
@@ -2420,6 +2422,10 @@ def int_amdgcn_mov_dpp :
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, IntrNoCallback, IntrNoFree]>;
+// The llvm.amdgcn.update.dpp.i32 intrinsic represents the AMDGPU update.dpp
+// operation. It takes the old value, the source operand, the DPP control, the
+// row mask, the bank mask, and the bound control, and is equivalent to the
+// v_mov_b32 sequence below. Prefer it over llvm.amdgcn.mov.dpp.i32.
// llvm.amdgcn.update.dpp.i32 <old> <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
// Should be equivalent to:
// v_mov_b32 <dest> <old>
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3f27e1541cf38..5b6ec4e5cf8a2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -410,6 +410,61 @@ def AMDGPU_RawBufferAtomicUminOp :
let hasVerifier = 1;
}
+def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
+ "The possible permutations for a DPP operation",
+ [
+ I32EnumAttrCase<"quad_perm", 0>,
+ I32EnumAttrCase<"row_shl", 1>,
+ I32EnumAttrCase<"row_shr", 2>,
+ I32EnumAttrCase<"row_ror", 3>,
+ I32EnumAttrCase<"wave_shl", 4>,
+ I32EnumAttrCase<"wave_shr", 5>,
+ I32EnumAttrCase<"wave_ror", 6>,
+ I32EnumAttrCase<"wave_rol", 7>,
+ I32EnumAttrCase<"row_mirror", 8>,
+ I32EnumAttrCase<"row_half_mirror", 9>,
+ I32EnumAttrCase<"row_bcast_15", 10>,
+ I32EnumAttrCase<"row_bcast_31", 11>
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_DPPPermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_DPPPerm,
+ "dpp_perm">;
+
+// Only integer and floating-point values are accepted: both the verifier and
+// the ROCDL lowering query the operand bit width (getIntOrFloatBitWidth),
+// which is only defined for those types.
+def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>,
+ Arguments<(ins AnyTypeOf<[AnyInteger, AnyFloat]>:$old,
+ AnyTypeOf<[AnyInteger, AnyFloat]>:$src,
+ AMDGPU_DPPPermAttr:$kind,
+ OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
+ DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
+ DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
+ DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
+ let summary = "AMDGPU DPP operation";
+ let description = [{
+ This operation represents DPP functionality in a GPU program. `src` is
+ permuted across lanes as selected by `kind` and `permArgument`; `old`
+ supplies the initial value of the destination, as for the
+ `llvm.amdgcn.update.dpp` intrinsic this operation is lowered to. `old`,
+ `src` and the result must all have the same type.
+
+ DPP provides the following operations:
+ - Full crossbar in a group of four (`quad_perm`)
+ - Wavefront shift left by one lane (`wave_shl`)
+ - Wavefront shift right by one lane (`wave_shr`)
+ - Wavefront rotate right by one lane (`wave_ror`)
+ - Wavefront rotate left by one lane (`wave_rol`)
+ - Row shift left by 1-15 lanes (`row_shl`)
+ - Row shift right by 1-15 lanes (`row_shr`)
+ - Row rotate right by 1-15 lanes (`row_ror`)
+ - Reverse within a row (`row_mirror`)
+ - Reverse within a half-row (`row_half_mirror`)
+ - Broadcast the 15th lane of each row to the next row (`row_bcast_15`)
+ - Broadcast lane 31 to rows 2 and 3 (`row_bcast_31`)
+
+ For `quad_perm`, `permArgument` is an array of exactly four integers, each
+ in [0, 3]. For `row_shl`, `row_shr` and `row_ror` it is the i32 shift or
+ rotate amount, in [1, 15]. All other kinds take no `permArgument` (a unit
+ attribute is also accepted).
+
+ `row_mask` and `bank_mask` default to 0xf and `bound_ctrl` defaults to
+ false; they are passed through to the intrinsic.
+
+ Example:
+
+ ```mlir
+ %0 = amdgpu.dpp %old %src row_shl ( 0x1 : i32 ) { row_mask = 0xa : i32 } : i32
+ %1 = amdgpu.dpp %oldf %srcf quad_perm ( [0, 2, 1, 1] ) : f32
+ ```
+ }];
+ let results = (outs AnyTypeOf<[AnyInteger, AnyFloat]>:$result);
+ let assemblyFormat = [{
+ $old $src $kind (`(` $permArgument^ `)`)? attr-dict `:` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..0cd599bb6f0db 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -591,6 +591,30 @@ def ROCDL_RawBufferAtomicUMinOp :
let hasCustomAssemblyFormat = 1;
}
+// DPP update intrinsic: lowers to a call of llvm.amdgcn.update.dpp that is
+// overloaded on the type of `src`. `old`, `src` and the result must share one
+// type; the control fields (dppCtrl, rowMask, bankMask, boundCtrl) are
+// attributes that are emitted as i32/i1 constant arguments of the call.
+def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
+ [AllTypesMatch<["res", "src", "old"]>], 1>,
+ Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
+ I32Attr:$bankMask, I1Attr:$boundCtrl)> {
+ let results = (outs LLVM_Type:$res);
+ // Printed form, e.g.: rocdl.update.dpp %old, %src with 257, 10, 15, false : i32
+ let assemblyFormat = [{
+ attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
+ }];
+ string llvmBuilder = [{
+ auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
+ llvm::Value *args[] = {
+ moduleTranslation.lookupValue(op.getOld()),
+ moduleTranslation.lookupValue(op.getSrc()),
+ builder.getInt32(op.getDppCtrl()),
+ builder.getInt32(op.getRowMask()),
+ builder.getInt32(op.getBankMask()),
+ builder.getInt1(op.getBoundCtrl())
+ };
+ $res = createIntrinsicCall(builder,
+ llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
+ }];
+}
+
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 033e66c6118f3..de0887b95b9a4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -828,6 +828,147 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
return success();
}
+// Lowers amdgpu.dpp to rocdl.update.dpp. Float operands are bitcast to an
+// integer of the same width; operands narrower than 32 bits are inserted into
+// element 0 of an undef vector that is bitcast to i32. The intrinsic result is
+// truncated and/or bitcast back to the source type.
+struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
+ AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Fetch the adaptor's (type-converted) operands and the integer types used
+ // to move them through the intrinsic.
+ Location loc = DppOp.getLoc();
+ Value src = adaptor.getSrc();
+ Value old = adaptor.getOld();
+ Type srcType = src.getType();
+ Type oldType = old.getType();
+ auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
+ auto llvmSrcIntType = typeConverter->convertType(
+ rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
+
+ // If the source type is less or equal to i32 or f32, use bitcast to convert
+ // it to i32.
+ auto convertOperand = [&](Value operand, Type operandType) {
+ if (llvm::isa<FloatType>(operandType)) {
+ operand =
+ rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+ }
+
+ if (operandType.getIntOrFloatBitWidth() < 32) {
+ auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
+ 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
+ operand = rewriter.create<LLVM::InsertElementOp>(
+ loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
+ operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
+ }
+ return operand;
+ };
+
+ // DPP control encodings, taken from llvm/lib/Target/AMDGPU/SIDefines.h;
+ // keep them in sync with that file.
+ enum DppCtrl : unsigned {
+ ROW_SHL0 = 0x100,
+ ROW_SHR0 = 0x110,
+ ROW_ROR0 = 0x120,
+ WAVE_SHL1 = 0x130,
+ WAVE_ROL1 = 0x134,
+ WAVE_SHR1 = 0x138,
+ WAVE_ROR1 = 0x13C,
+ ROW_MIRROR = 0x140,
+ ROW_HALF_MIRROR = 0x141,
+ BCAST15 = 0x142,
+ BCAST31 = 0x143,
+ };
+
+ auto kind = DppOp.getKind();
+ auto permArgument = DppOp.getPermArgument();
+ // NOTE: this local shares its name with the enum above; `DppCtrl::X` still
+ // names the enum because lookup of a name before `::` ignores variables.
+ uint32_t DppCtrl = 0;
+ // The verifier is expected to guarantee the argument shape for each kind;
+ // check it anyway so malformed IR makes the pattern fail instead of
+ // asserting in cast<> or dereferencing an empty optional.
+ Attribute permAttr = permArgument.value_or(Attribute());
+
+ switch (kind) {
+
+ case DPPPerm::quad_perm: {
+ auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permAttr);
+ if (!quadPermAttr || quadPermAttr.size() != 4)
+ return rewriter.notifyMatchFailure(
+ DppOp, "quad_perm requires an array of 4 lane selectors");
+ // Selector i occupies bits [2*i+1 : 2*i] of the control word.
+ uint32_t i = 0;
+ for (Attribute elem : quadPermAttr) {
+ auto selector = dyn_cast<IntegerAttr>(elem);
+ if (!selector)
+ return rewriter.notifyMatchFailure(
+ DppOp, "quad_perm lane selectors must be integers");
+ uint32_t num = selector.getInt();
+ DppCtrl |= num << (i * 2);
+ ++i;
+ }
+ break;
+ }
+ case DPPPerm::row_shl:
+ case DPPPerm::row_shr:
+ case DPPPerm::row_ror: {
+ auto intAttr = dyn_cast_or_null<IntegerAttr>(permAttr);
+ if (!intAttr)
+ return rewriter.notifyMatchFailure(
+ DppOp, "row_shl/row_shr/row_ror require an integer amount");
+ // The amount is added to the zero-amount encoding of the selected
+ // operation (e.g. row_shl 1 -> ROW_SHL0 + 1).
+ uint32_t base = DppCtrl::ROW_SHL0;
+ if (kind == DPPPerm::row_shr)
+ base = DppCtrl::ROW_SHR0;
+ else if (kind == DPPPerm::row_ror)
+ base = DppCtrl::ROW_ROR0;
+ DppCtrl = intAttr.getInt() + base;
+ break;
+ }
+ case DPPPerm::wave_shl:
+ DppCtrl = DppCtrl::WAVE_SHL1;
+ break;
+ case DPPPerm::wave_shr:
+ DppCtrl = DppCtrl::WAVE_SHR1;
+ break;
+ case DPPPerm::wave_rol:
+ DppCtrl = DppCtrl::WAVE_ROL1;
+ break;
+ case DPPPerm::wave_ror:
+ DppCtrl = DppCtrl::WAVE_ROR1;
+ break;
+ case DPPPerm::row_mirror:
+ DppCtrl = DppCtrl::ROW_MIRROR;
+ break;
+ case DPPPerm::row_half_mirror:
+ DppCtrl = DppCtrl::ROW_HALF_MIRROR;
+ break;
+ case DPPPerm::row_bcast_15:
+ DppCtrl = DppCtrl::BCAST15;
+ break;
+ case DPPPerm::row_bcast_31:
+ DppCtrl = DppCtrl::BCAST31;
+ break;
+ }
+
+ // row_mask, bank_mask and bound_ctrl are DefaultValuedAttrs, so the
+ // generated getters return the declared default when the attribute is
+ // absent rather than a null attribute.
+ auto rowMask = DppOp.getRowMask();
+ auto bankMask = DppOp.getBankMask();
+ bool boundCtrl = DppOp.getBoundCtrl();
+
+ // Materialize the converted operands only once every check above has
+ // passed, so a failed match does not leave new operations behind.
+ src = convertOperand(src, srcType);
+ old = convertOperand(old, oldType);
+
+ // create a ROCDL_DPPMovOp instruction with the appropriate attributes
+ auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
+ loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+
+ Value result = dppMovOp.getRes();
+ if (srcType.getIntOrFloatBitWidth() < 32) {
+ result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+ }
+
+ if (!llvm::isa<IntegerType>(srcType)) {
+ result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ }
+
+ // We are replacing the AMDGPU_DPPOp instruction with the new
+ // ROCDL_DPPUpdateOp instruction
+ rewriter.replaceOp(DppOp, ValueRange(result));
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
ConvertAMDGPUToROCDLPass() = default;
@@ -879,8 +1020,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicUminOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
- LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering,
- ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+ AMDGPUDPPLowering, LDSBarrierOpLowering, MFMAOpLowering,
+ WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering>(converter, chipset);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5..08e912d4e43e0 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -326,6 +326,68 @@ LogicalResult MFMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// DPPOp
+//===----------------------------------------------------------------------===//
+LogicalResult DPPOp::verify() {
+ Type srcType = getSrc().getType();
+ if (srcType.getIntOrFloatBitWidth() > 32) {
+ return emitOpError("integer and floating point types larger than 32 bits "
+ "are not supported");
+ }
+
+ // NOTE(review): getIntOrFloatBitWidth() above asserts unless `src` is an
+ // integer or float type; confirm the ODS type constraint on `src` rules out
+ // anything else (e.g. vectors).
+
+ // Validate `permArgument` against `kind`; the ROCDL lowering decodes the
+ // argument based on these guarantees.
+ DPPPerm kind = getKind();
+ Attribute permArgument = getPermArgument().value_or(Attribute{});
+
+ switch (kind) {
+
+ case DPPPerm::quad_perm: {
+ // Exactly four lane selectors, each a signless integer in [0, 3].
+ auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
+ if (!quadPermAttr || quadPermAttr.size() != 4) {
+ return emitOpError("quad_perm attribute must have exactly 4 elements");
+ }
+ for (Attribute elem : quadPermAttr) {
+ // ODS does not constrain ArrayAttr elements, so check their kind here
+ // instead of asserting in a cast.
+ auto intAttr = dyn_cast<IntegerAttr>(elem);
+ if (!intAttr || !intAttr.getType().isSignlessInteger()) {
+ return emitOpError("quad_perm elements must be signless integers");
+ }
+ // Unsigned APInt comparison: negative values read as large unsigned
+ // values and wide values are not truncated, so both are rejected.
+ if (!intAttr.getValue().ult(4)) {
+ return emitOpError(
+ "Each element of quad_perm must be in the range [0, 3]");
+ }
+ }
+ } break;
+
+ case DPPPerm::row_shl:
+ case DPPPerm::row_shr:
+ case DPPPerm::row_ror: {
+ if (!permArgument) {
+ return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
+ "' value not specified");
+ }
+ // A unit or array argument would otherwise pass verification and then
+ // fail to decode during lowering.
+ auto intAttr = dyn_cast<IntegerAttr>(permArgument);
+ if (!intAttr) {
+ return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
+ "' value must be an integer");
+ }
+ int64_t attrValue = intAttr.getInt();
+ if (attrValue < 1 || attrValue > 15) {
+ return emitOpError("Attribute value must be between 1 and 15");
+ }
+ } break;
+
+ case DPPPerm::wave_shl:
+ case DPPPerm::wave_shr:
+ case DPPPerm::wave_rol:
+ case DPPPerm::wave_ror:
+ case DPPPerm::row_mirror:
+ case DPPPerm::row_half_mirror:
+ case DPPPerm::row_bcast_15:
+ case DPPPerm::row_bcast_31: {
+ if (permArgument && !isa<UnitAttr>(permArgument)) {
+ return emitOpError("Expected unit attribute for permArgument, but found "
+ "non-trivial argument");
+ }
+ break;
+ }
+ }
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
new file mode 100644
index 0000000000000..c982342c879f5
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx908 %s | FileCheck %s
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx90a %s | FileCheck %s
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx942 %s | FileCheck %s
+
+func.func @test_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @test_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 257, 10, 15, false : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 row_shl ( 0x1 : i32 ) { row_mask = 0xa : i32, bound_ctrl = false } : i32
+ return %0 : i32
+}
+
+func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @quad_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 145, 1, 1, true : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [1,0,1,2] ) { row_mask = 0x1 : i32, bank_mask = 0x1 : i32, bound_ctrl = true } : i32
+ return %0 : i32
+}
+
+func.func @quad_perm_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @quad_perm_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i32
+ return %0 : i32
+}
+
+func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @wave_shr_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 wave_shr { row_mask = 0xa : i32, bank_mask = 0x1 : i32, bound_ctrl = true } : i32
+ return %0 : i32
+}
+
+func.func @row_bcast_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @row_bcast_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : i32
+ return %0 : i32
+}
+
+func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @row_bcast_dpp_f32
+ // CHECK: llvm.bitcast %arg1 : f32 to i32
+ // CHECK: llvm.bitcast %arg0 : f32 to i32
+ // CHECK: rocdl.update.dpp %1, %0 with 322, 15, 15, true : i32
+ // CHECK: llvm.bitcast %2 : i32 to f32
+ // CHECK: return %3 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
+ return %0 : f32
+}
+
+func.func @row_half_mirror_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @row_half_mirror_update_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 321, 15, 1, false : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 row_half_mirror { bank_mask = 0x1 : i32 } : i32
+ return %0 : i32
+}
+
+func.func @wave_rol_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @wave_rol_update_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 308, 10, 1, false : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 wave_rol { row_mask = 0xa : i32, bank_mask = 0x1 : i32 } : i32
+ return %0 : i32
+}
+
+func.func @test_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @test_dpp_f32
+ // CHECK: llvm.bitcast %arg1 : f32 to i32
+ // CHECK: llvm.bitcast %arg0 : f32 to i32
+ // CHECK: rocdl.update.dpp %1, %0 with 320, 1, 4, true : i32
+ // CHECK: llvm.bitcast %2 : i32 to f32
+ // CHECK: return %3 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 row_mirror { row_mask = 0x1 : i32, bank_mask = 0x4 : i32, bound_ctrl = true } : f32
+ return %0 : f32
+}
+
+func.func @test_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
+ // CHECK-LABEL: func @test_dpp_f16
+ // CHECK: llvm.bitcast %arg1 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %0, %1[%2 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %3 : vector<2xi16> to i32
+ // CHECK: llvm.bitcast %arg0 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %5, %6[%7 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %8 : vector<2xi16> to i32
+ // CHECK: rocdl.update.dpp %9, %4 with 273, 15, 3, false : i32
+ // CHECK: llvm.trunc %10 : i32 to i16
+ // CHECK: llvm.bitcast %11 : i16 to f16
+ // CHECK: return %12 : f16
+ %0 = amdgpu.dpp %arg0 %arg1 row_shr ( 0x1 : i32 ){ bank_mask = 0x3 : i32 } : f16
+ return %0 : f16
+}
+
+func.func @row_ror_dpp_i16(%arg0: i16, %arg1: i16) -> i16 {
+ // CHECK-LABEL: func @row_ror_dpp_i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %arg1, %0[%1 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %2 : vector<2xi16> to i32
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %arg0, %4[%5 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %6 : vector<2xi16> to i32
+ // CHECK: rocdl.update.dpp %7, %3 with 298, 10, 1, false : i32
+ // CHECK: llvm.trunc %8 : i32 to i16
+ // CHECK: return %9 : i16
+ %0 = amdgpu.dpp %arg0 %arg1 row_ror ( 0xa : i32 ) { row_mask = 0xa : i32, bank_mask = 0x1 : i32 } : i16
+ return %0 : i16
+}
+
+func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @quad_perm_update_dpp_f32
+ // CHECK: llvm.bitcast %arg1 : f32 to i32
+ // CHECK: llvm.bitcast %arg0 : f32 to i32
+ // CHECK: rocdl.update.dpp %1, %0 with 88, 15, 1, false : i32
+ // CHECK: llvm.bitcast %2 : i32 to f32
+ // CHECK: return %3 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
+ return %0 : f32
+}
+
+func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
+ // CHECK-LABEL: func @row_bcast_update_dpp_f16
+ // CHECK: llvm.bitcast %arg1 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %0, %1[%2 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %3 : vector<2xi16> to i32
+ // CHECK: llvm.bitcast %arg0 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %5, %6[%7 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %8 : vector<2xi16> to i32
+ // CHECK: rocdl.update.dpp %9, %4 with 322, 15, 15, true : i32
+ // CHECK: llvm.trunc %10 : i32 to i16
+ // CHECK: llvm.bitcast %11 : i16 to f16
+ // CHECK: return %12 : f16
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f16
+ return %0 : f16
+}
>From d0690311ca7b03d3a717f96ab0fd0f5783bdb6c8 Mon Sep 17 00:00:00 2001
From: stefankoncarevic <skoncare at amd.com>
Date: Thu, 16 May 2024 13:30:40 +0000
Subject: [PATCH 2/2] [mlir][AMDGPU] Added support for 64-bit operands in
 the amdgpu.dpp lowering to ROCDL::DPPUpdateOp.
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 2 +-
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 38 +++++----
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 4 +-
mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir | 77 ++++++++-----------
4 files changed, 60 insertions(+), 61 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0cd599bb6f0db..7e2710d3ac367 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -609,7 +609,7 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
builder.getInt32(op.getRowMask()),
builder.getInt32(op.getBankMask()),
builder.getInt1(op.getBoundCtrl())
- };
+ };
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
}];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index de0887b95b9a4..b24f70c824ad4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -845,25 +845,34 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
Value old = adaptor.getOld();
Type srcType = src.getType();
Type oldType = old.getType();
- auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
+ Type llvmType = nullptr;
+ if (srcType.getIntOrFloatBitWidth() < 32) {
+ llvmType = rewriter.getI32Type();
+ } else if (isa<FloatType>(srcType)) {
+ llvmType = (srcType.getIntOrFloatBitWidth() == 32)
+ ? rewriter.getF32Type()
+ : rewriter.getF64Type();
+ } else if (isa<IntegerType>(srcType)) {
+ llvmType = (srcType.getIntOrFloatBitWidth() == 32)
+ ? rewriter.getI32Type()
+ : rewriter.getI64Type();
+ }
auto llvmSrcIntType = typeConverter->convertType(
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
- // If the source type is less or equal to i32 or f32, use bitcast to convert
- // it to i32.
+ // If the operand is 16 bits or narrower, pack it into a vector and bitcast it to i32.
auto convertOperand = [&](Value operand, Type operandType) {
- if (llvm::isa<FloatType>(operandType)) {
- operand =
- rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
- }
-
- if (operandType.getIntOrFloatBitWidth() < 32) {
+ if (operandType.getIntOrFloatBitWidth() <= 16) {
+ if (llvm::isa<FloatType>(operandType)) {
+ operand =
+ rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+ }
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
- operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
+ operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
}
return operand;
};
@@ -951,15 +960,14 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
- loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+ loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
Value result = dppMovOp.getRes();
if (srcType.getIntOrFloatBitWidth() < 32) {
result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
- }
-
- if (!llvm::isa<IntegerType>(srcType)) {
- result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ if (!llvm::isa<IntegerType>(srcType)) {
+ result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ }
}
// We are replacing the AMDGPU_DPPOp instruction with the new
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 08e912d4e43e0..c1a785fb25478 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -331,8 +331,8 @@ LogicalResult MFMAOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult DPPOp::verify() {
Type srcType = getSrc().getType();
- if (srcType.getIntOrFloatBitWidth() > 32) {
- return emitOpError("integer and floating point types larger than 32 bits "
+ if (srcType.getIntOrFloatBitWidth() > 64) {
+ return emitOpError("integer and floating point types larger than 64 bits "
"are not supported");
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
index c982342c879f5..14691e73e62d7 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
@@ -18,14 +18,6 @@ func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
return %0 : i32
}
-func.func @quad_perm_dpp(%arg0: i32, %arg1: i32) -> i32 {
- // CHECK-LABEL: func @quad_perm_dpp
- // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i32
- // CHECK: return %0 : i32
- %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i32
- return %0 : i32
-}
-
func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
// CHECK-LABEL: func @wave_shr_dpp
// CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
@@ -34,25 +26,6 @@ func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
return %0 : i32
}
-func.func @row_bcast_dpp(%arg0: i32, %arg1: i32) -> i32 {
- // CHECK-LABEL: func @row_bcast_dpp
- // CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : i32
- // CHECK: return %0 : i32
- %0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : i32
- return %0 : i32
-}
-
-func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
- // CHECK-LABEL: func @row_bcast_dpp_f32
- // CHECK: llvm.bitcast %arg1 : f32 to i32
- // CHECK: llvm.bitcast %arg0 : f32 to i32
- // CHECK: rocdl.update.dpp %1, %0 with 322, 15, 15, true : i32
- // CHECK: llvm.bitcast %2 : i32 to f32
- // CHECK: return %3 : f32
- %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
- return %0 : f32
-}
-
func.func @row_half_mirror_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
// CHECK-LABEL: func @row_half_mirror_update_dpp
// CHECK: rocdl.update.dpp %arg0, %arg1 with 321, 15, 1, false : i32
@@ -69,17 +42,46 @@ func.func @wave_rol_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
return %0 : i32
}
+// 32-bit float operands are passed to rocdl.update.dpp directly, without bitcasts.
+func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @row_bcast_dpp_f32
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 322, 15, 15, true : f32
+ // CHECK: return %0 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
+ return %0 : f32
+}
+
func.func @test_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
// CHECK-LABEL: func @test_dpp_f32
- // CHECK: llvm.bitcast %arg1 : f32 to i32
- // CHECK: llvm.bitcast %arg0 : f32 to i32
- // CHECK: rocdl.update.dpp %1, %0 with 320, 1, 4, true : i32
- // CHECK: llvm.bitcast %2 : i32 to f32
- // CHECK: return %3 : f32
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 320, 1, 4, true : f32
+ // CHECK: return %0 : f32
%0 = amdgpu.dpp %arg0 %arg1 row_mirror { row_mask = 0x1 : i32, bank_mask = 0x4 : i32, bound_ctrl = true } : f32
return %0 : f32
}
+func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @quad_perm_update_dpp_f32
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 1, false : f32
+ // CHECK: return %0 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
+ return %0 : f32
+}
+
+// 64-bit integer operands are passed to rocdl.update.dpp unchanged.
+func.func @quad_perm_dpp(%arg0: i64, %arg1: i64) -> i64 {
+ // CHECK-LABEL: func @quad_perm_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i64
+ // CHECK: return %0 : i64
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i64
+ return %0 : i64
+}
+
+// 64-bit float operands are passed to rocdl.update.dpp unchanged.
+func.func @row_bcast_dpp(%arg0: f64, %arg1: f64) -> f64 {
+ // CHECK-LABEL: func @row_bcast_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : f64
+ // CHECK: return %0 : f64
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : f64
+ return %0 : f64
+}
+
func.func @test_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
// CHECK-LABEL: func @test_dpp_f16
// CHECK: llvm.bitcast %arg1 : f16 to i16
@@ -117,17 +119,6 @@ func.func @row_shl_dpp_i16(%arg0: i16, %arg1: i16) -> i16 {
return %0 : i16
}
-func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
- // CHECK-LABEL: func @quad_perm_update_dpp_f32
- // CHECK: llvm.bitcast %arg1 : f32 to i32
- // CHECK: llvm.bitcast %arg0 : f32 to i32
- // CHECK: rocdl.update.dpp %1, %0 with 88, 15, 1, false : i32
- // CHECK: llvm.bitcast %2 : i32 to f32
- // CHECK: return %3 : f32
- %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
- return %0 : f32
-}
-
func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
// CHECK-LABEL: func @row_bcast_update_dpp_f16
// CHECK: llvm.bitcast %arg1 : f16 to i16
More information about the Mlir-commits
mailing list