[Mlir-commits] [mlir] 1164e4a - [mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR. (#89233)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 16 09:19:44 PDT 2024
Author: stefankoncarevic
Date: 2024-08-16T11:19:39-05:00
New Revision: 1164e4aef2844ac3d35153ffe6376db9abda704a
URL: https://github.com/llvm/llvm-project/commit/1164e4aef2844ac3d35153ffe6376db9abda704a
DIFF: https://github.com/llvm/llvm-project/commit/1164e4aef2844ac3d35153ffe6376db9abda704a.diff
LOG: [mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR. (#89233)
Defined AMDGPU DPP operation in mlir to represent semantics. Introduced
a new enumeration attribute for different permutations and allowed for
different types of arguments. Implemented constant attribute handling
for ROCDL::DPPMovOp operation. The operation now correctly accepts
constant attributes for dppCtrl, rowMask, bankMask, boundCtrl, and
passes them to the corresponding LLVM intrinsic.
Added:
mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
Modified:
llvm/include/llvm/IR/IntrinsicsAMDGPU.td
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 8c25467cc5e4b1..7c93c19a410e43 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2343,6 +2343,8 @@ def int_amdgcn_buffer_wbinvl1_vol :
// VI Intrinsics
//===----------------------------------------------------------------------===//
+// The llvm.amdgcn.mov.dpp.i32 intrinsic represents the mov.dpp operation in AMDGPU.
+// This operation is being deprecated and can be replaced with llvm.amdgcn.update.dpp.i32.
// llvm.amdgcn.mov.dpp.i32 <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
def int_amdgcn_mov_dpp :
Intrinsic<[llvm_anyint_ty],
@@ -2352,6 +2354,10 @@ def int_amdgcn_mov_dpp :
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, IntrNoCallback, IntrNoFree]>;
+// The llvm.amdgcn.update.dpp.i32 intrinsic represents the update.dpp operation in AMDGPU.
+// It takes an old value, a source operand, a DPP control operand, a row mask, a bank mask, and a bound control.
+// This operation is equivalent to a sequence of v_mov_b32 operations.
+// It is preferred over llvm.amdgcn.mov.dpp.i32 for future use.
// llvm.amdgcn.update.dpp.i32 <old> <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
// Should be equivalent to:
// v_mov_b32 <dest> <old>
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index dddbd837c1cbc7..97e0580c898080 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -402,6 +402,61 @@ def AMDGPU_RawBufferAtomicUminOp :
let hasVerifier = 1;
}
+def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
+ "The possible permutations for a DPP operation",
+ [
+ I32EnumAttrCase<"quad_perm", 0>,
+ I32EnumAttrCase<"row_shl", 1>,
+ I32EnumAttrCase<"row_shr", 2>,
+ I32EnumAttrCase<"row_ror", 3>,
+ I32EnumAttrCase<"wave_shl", 4>,
+ I32EnumAttrCase<"wave_shr", 5>,
+ I32EnumAttrCase<"wave_ror", 6>,
+ I32EnumAttrCase<"wave_rol", 7>,
+ I32EnumAttrCase<"row_mirror", 8>,
+ I32EnumAttrCase<"row_half_mirror", 9>,
+ I32EnumAttrCase<"row_bcast_15", 10>,
+ I32EnumAttrCase<"row_bcast_31", 11>
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_DPPPermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_DPPPerm,
+ "dpp_perm">;
+
+def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>,
+ Arguments<(ins AnyType:$old,
+ AnyType:$src,
+ AMDGPU_DPPPermAttr:$kind,
+ OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
+ DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
+ DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
+ DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
+ let summary = "AMDGPU DPP operation";
+ let description = [{
+ This operation represents DPP functionality in a GPU program.
+ DPP provides the following operations:
+ - Full crossbar in a group of four (`quad_perm`)
+ - Wavefront shift left by one lane (`wave_shl`)
+ - Wavefront shift right by one lane (`wave_shr`)
+ - Wavefront rotate right by one lane (`wave_ror`)
+ - Wavefront rotate left by one lane (`wave_rol`)
+ - Row shift left by 1–15 lanes (`row_shl`)
+ - Row shift right by 1–15 lanes (`row_shr`)
+ - Row rotate right by 1–15 lanes (`row_ror`)
+ - Reverse within a row (`row_mirror`)
+ - Reverse within a half-row (`row_half_mirror`)
+  - Broadcast the 15th lane of each row to the next row (`row_bcast_15`)
+  - Broadcast lane 31 to rows 2 and 3 (`row_bcast_31`)
+ }];
+ let results = (outs AnyType:$result);
+ let assemblyFormat = [{
+ $old $src $kind (`(` $permArgument^ `)`)? attr-dict `:` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index b81b7e577b213a..a1e6fc3e299009 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -555,6 +555,30 @@ def ROCDL_RawBufferAtomicUMinOp :
let hasCustomAssemblyFormat = 1;
}
+// DPP Update intrinsic
+def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
+ [AllTypesMatch<["res", "src", "old"]>], 1>,
+ Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
+ I32Attr:$bankMask, I1Attr:$boundCtrl)> {
+ let results = (outs LLVM_Type:$res);
+ let assemblyFormat = [{
+ attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
+ }];
+ string llvmBuilder = [{
+ auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
+ llvm::Value *args[] = {
+ moduleTranslation.lookupValue(op.getOld()),
+ moduleTranslation.lookupValue(op.getSrc()),
+ builder.getInt32(op.getDppCtrl()),
+ builder.getInt32(op.getRowMask()),
+ builder.getInt32(op.getBankMask()),
+ builder.getInt1(op.getBoundCtrl())
+ };
+ $res = createIntrinsicCall(builder,
+ llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
+ }];
+}
+
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b808738804030f..809e9448e80abf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -844,6 +844,155 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
return success();
}
+// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
+// operation into the corresponding ROCDL instructions.
+struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
+ AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Convert the source operand to the corresponding LLVM type
+ Location loc = DppOp.getLoc();
+ Value src = adaptor.getSrc();
+ Value old = adaptor.getOld();
+ Type srcType = src.getType();
+ Type oldType = old.getType();
+ Type llvmType = nullptr;
+ if (srcType.getIntOrFloatBitWidth() < 32) {
+ llvmType = rewriter.getI32Type();
+ } else if (isa<FloatType>(srcType)) {
+ llvmType = (srcType.getIntOrFloatBitWidth() == 32)
+ ? rewriter.getF32Type()
+ : rewriter.getF64Type();
+ } else if (isa<IntegerType>(srcType)) {
+ llvmType = (srcType.getIntOrFloatBitWidth() == 32)
+ ? rewriter.getI32Type()
+ : rewriter.getI64Type();
+ }
+ auto llvmSrcIntType = typeConverter->convertType(
+ rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
+
+  // If the source type is less than 32 bits, use bitcast to convert it to i32.
+ auto convertOperand = [&](Value operand, Type operandType) {
+ if (operandType.getIntOrFloatBitWidth() <= 16) {
+ if (llvm::isa<FloatType>(operandType)) {
+ operand =
+ rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+ }
+ auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
+ 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
+ operand = rewriter.create<LLVM::InsertElementOp>(
+ loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
+ operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
+ }
+ return operand;
+ };
+
+ src = convertOperand(src, srcType);
+ old = convertOperand(old, oldType);
+
+ // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
+ enum DppCtrl : unsigned {
+ ROW_SHL0 = 0x100,
+ ROW_SHR0 = 0x110,
+ ROW_ROR0 = 0x120,
+ WAVE_SHL1 = 0x130,
+ WAVE_ROL1 = 0x134,
+ WAVE_SHR1 = 0x138,
+ WAVE_ROR1 = 0x13C,
+ ROW_MIRROR = 0x140,
+ ROW_HALF_MIRROR = 0x141,
+ BCAST15 = 0x142,
+ BCAST31 = 0x143,
+ };
+
+ auto kind = DppOp.getKind();
+ auto permArgument = DppOp.getPermArgument();
+ uint32_t DppCtrl = 0;
+
+ switch (kind) {
+
+ case DPPPerm::quad_perm:
+ if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
+ int32_t i = 0;
+ for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
+ uint32_t num = elem.getInt();
+ DppCtrl |= num << (i * 2);
+ i++;
+ }
+ }
+ break;
+ case DPPPerm::row_shl:
+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+ DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
+ }
+ break;
+ case DPPPerm::row_shr:
+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+ DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
+ }
+ break;
+ case DPPPerm::row_ror:
+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+ DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
+ }
+ break;
+ case DPPPerm::wave_shl:
+ DppCtrl = DppCtrl::WAVE_SHL1;
+ break;
+ case DPPPerm::wave_shr:
+ DppCtrl = DppCtrl::WAVE_SHR1;
+ break;
+ case DPPPerm::wave_rol:
+ DppCtrl = DppCtrl::WAVE_ROL1;
+ break;
+ case DPPPerm::wave_ror:
+ DppCtrl = DppCtrl::WAVE_ROR1;
+ break;
+ case DPPPerm::row_mirror:
+ DppCtrl = DppCtrl::ROW_MIRROR;
+ break;
+ case DPPPerm::row_half_mirror:
+ DppCtrl = DppCtrl::ROW_HALF_MIRROR;
+ break;
+ case DPPPerm::row_bcast_15:
+ DppCtrl = DppCtrl::BCAST15;
+ break;
+ case DPPPerm::row_bcast_31:
+ DppCtrl = DppCtrl::BCAST31;
+ break;
+ }
+
+ // Check for row_mask, bank_mask, bound_ctrl if they exist and create
+ // constants
+ auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
+ auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
+ bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
+
+    // Create a ROCDL_DPPUpdateOp instruction with the appropriate attributes.
+ auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
+ loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+
+ Value result = dppMovOp.getRes();
+ if (srcType.getIntOrFloatBitWidth() < 32) {
+ result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+ if (!llvm::isa<IntegerType>(srcType)) {
+ result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ }
+ }
+
+    // We are replacing the AMDGPU_DPPOp instruction with the new
+    // ROCDL_DPPUpdateOp instruction.
+ rewriter.replaceOp(DppOp, ValueRange(result));
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
ConvertAMDGPUToROCDLPass() = default;
@@ -895,9 +1044,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicUminOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
- LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering>(converter, chipset);
+ AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
+ MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
+ chipset);
}
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..c1a785fb25478d 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -326,6 +326,68 @@ LogicalResult MFMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// DPPOp
+//===----------------------------------------------------------------------===//
+LogicalResult DPPOp::verify() {
+ Type srcType = getSrc().getType();
+ if (srcType.getIntOrFloatBitWidth() > 64) {
+ return emitOpError("integer and floating point types larger than 64 bits "
+ "are not supported");
+ }
+
+ DPPPerm kind = getKind();
+ Attribute permArgument = getPermArgument().value_or(Attribute{});
+
+ switch (kind) {
+
+ case DPPPerm::quad_perm: {
+ auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
+ if (!quadPermAttr || quadPermAttr.size() != 4) {
+ return emitOpError("quad_perm attribute must have exactly 4 elements");
+ }
+ for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
+ uint32_t num = elem.getInt();
+ if (num < 0 || num > 3) {
+ return emitOpError(
+ "Each element of quad_perm must be in the range [0, 3]");
+ }
+ }
+ } break;
+
+ case DPPPerm::row_shl:
+ case DPPPerm::row_shr:
+ case DPPPerm::row_ror: {
+ if (!permArgument) {
+ return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
+ "' value not specified");
+ }
+ if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
+ uint32_t attrValue = intAttr.getInt();
+ if (attrValue < 1 || attrValue > 15) {
+ return emitOpError("Attribute value must be between 1 and 15");
+ }
+ }
+ } break;
+
+ case DPPPerm::wave_shl:
+ case DPPPerm::wave_shr:
+ case DPPPerm::wave_rol:
+ case DPPPerm::wave_ror:
+ case DPPPerm::row_mirror:
+ case DPPPerm::row_half_mirror:
+ case DPPPerm::row_bcast_15:
+ case DPPPerm::row_bcast_31: {
+ if (permArgument && !isa<UnitAttr>(permArgument)) {
+ return emitOpError("Expected unit attribute for permArgument, but found "
+ "non-trivial argument");
+ }
+ break;
+ }
+ }
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
new file mode 100644
index 00000000000000..14691e73e62d78
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
@@ -0,0 +1,139 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx908 %s | FileCheck %s
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx90a %s | FileCheck %s
+// RUN: mlir-opt -convert-amdgpu-to-rocdl=chipset=gfx942 %s | FileCheck %s
+
+func.func @test_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @test_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 257, 10, 15, false : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 row_shl ( 0x1 : i32 ) { row_mask = 0xa : i32, bound_ctrl = false } : i32
+ return %0 : i32
+}
+
+func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @quad_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 145, 1, 1, true : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [1,0,1,2] ) { row_mask = 0x1 : i32, bank_mask = 0x1 : i32, bound_ctrl = true } : i32
+ return %0 : i32
+}
+
+func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @wave_shr_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 wave_shr { row_mask = 0xa : i32, bank_mask = 0x1 : i32, bound_ctrl = true } : i32
+ return %0 : i32
+}
+
+func.func @row_half_mirror_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @row_half_mirror_update_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 321, 15, 1, false : i32
+ // CHECK: return %0 : i32
+%0 = amdgpu.dpp %arg0 %arg1 row_half_mirror { bank_mask = 0x1 : i32 } : i32
+ return %0 : i32
+}
+
+func.func @wave_rol_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK-LABEL: func @wave_rol_update_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 308, 10, 1, false : i32
+ // CHECK: return %0 : i32
+ %0 = amdgpu.dpp %arg0 %arg1 wave_rol { row_mask = 0xa : i32, bank_mask = 0x1 : i32 } : i32
+ return %0 : i32
+}
+
+func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @row_bcast_dpp_f32
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 322, 15, 15, true : f32
+ // CHECK: return %0 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
+ return %0 : f32
+}
+
+func.func @test_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @test_dpp_f32
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 320, 1, 4, true : f32
+ // CHECK: return %0 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 row_mirror { row_mask = 0x1 : i32, bank_mask = 0x4 : i32, bound_ctrl = true } : f32
+ return %0 : f32
+}
+
+func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
+ // CHECK-LABEL: func @quad_perm_update_dpp_f32
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 1, false : f32
+ // CHECK: return %0 : f32
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
+ return %0 : f32
+}
+
+func.func @quad_perm_dpp(%arg0: i64, %arg1: i64) -> i64 {
+ // CHECK-LABEL: func @quad_perm_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i64
+ // CHECK: return %0 : i64
+ %0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i64
+ return %0 : i64
+}
+
+func.func @row_bcast_dpp(%arg0: f64, %arg1: f64) -> f64 {
+ // CHECK-LABEL: func @row_bcast_dpp
+ // CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : f64
+ // CHECK: return %0 : f64
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : f64
+ return %0 : f64
+}
+
+func.func @test_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
+ // CHECK-LABEL: func @test_dpp_f16
+ // CHECK: llvm.bitcast %arg1 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %0, %1[%2 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %3 : vector<2xi16> to i32
+ // CHECK: llvm.bitcast %arg0 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %5, %6[%7 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %8 : vector<2xi16> to i32
+ // CHECK: rocdl.update.dpp %9, %4 with 273, 15, 3, false : i32
+ // CHECK: llvm.trunc %10 : i32 to i16
+ // CHECK: llvm.bitcast %11 : i16 to f16
+ // CHECK: return %12 : f16
+ %0 = amdgpu.dpp %arg0 %arg1 row_shr ( 0x1 : i32 ){ bank_mask = 0x3 : i32 } : f16
+ return %0 : f16
+}
+
+func.func @row_shl_dpp_i16(%arg0: i16, %arg1: i16) -> i16 {
+ // CHECK-LABEL: func @row_shl_dpp_i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %arg1, %0[%1 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %2 : vector<2xi16> to i32
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %arg0, %4[%5 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %6 : vector<2xi16> to i32
+ // CHECK: rocdl.update.dpp %7, %3 with 298, 10, 1, false : i32
+ // CHECK: llvm.trunc %8 : i32 to i16
+ // CHECK: return %9 : i16
+ %0 = amdgpu.dpp %arg0 %arg1 row_ror ( 0xa : i32 ) { row_mask = 0xa : i32, bank_mask = 0x1 : i32 } : i16
+ return %0 : i16
+}
+
+func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
+ // CHECK-LABEL: func @row_bcast_update_dpp_f16
+ // CHECK: llvm.bitcast %arg1 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.insertelement %0, %1[%2 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %arg0 : f16 to i16
+ // CHECK: llvm.mlir.undef : vector<2xi16>
+ // CHECK: llvm.mlir.constant(0 : i32) : i32
+ // CHECK llvm.insertelement %5, %6[%7 : i32] : vector<2xi16>
+ // CHECK: llvm.bitcast %8 : vector<2xi16> to i32
+ // CHECK: rocdl.update.dpp %9, %4 with 322, 15, 15, true : i32
+ // CHECK: llvm.trunc %10 : i32 to i16
+ // CHECK: llvm.bitcast %11 : i16 to f16
+ // CHECK: return %12 : f16
+ %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f16
+ return %0 : f16
+}
More information about the Mlir-commits
mailing list