[Mlir-commits] [mlir] 4b3eaee - [mlir][AMDGPU] Define wrappers for WMMA matrix ops
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Jul 20 11:38:41 PDT 2023
Author: Giuseppe Rossini
Date: 2023-07-20T18:38:35Z
New Revision: 4b3eaee2701afa2549c9cd0f78692598e9bfd44e
URL: https://github.com/llvm/llvm-project/commit/4b3eaee2701afa2549c9cd0f78692598e9bfd44e
DIFF: https://github.com/llvm/llvm-project/commit/4b3eaee2701afa2549c9cd0f78692598e9bfd44e.diff
LOG: [mlir][AMDGPU] Define wrappers for WMMA matrix ops
Wave Matrix Multiply Accumulate (WMMA) is the set of RDNA3 instructions used
to accelerate matrix multiplication. LLVM already provides intrinsics that
generate WMMA instructions; this change uses those intrinsics to enable the
feature in MLIR.
Reviewed By: krzysz00
Differential Revision: https://reviews.llvm.org/D152451
Added:
mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/test/Dialect/AMDGPU/invalid.mlir
mlir/test/Dialect/AMDGPU/ops.mlir
mlir/test/Target/LLVMIR/rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 5607fae575500a3..b41fc54e2c078f1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -380,6 +380,10 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
+// wmma
+// Source (A/B) operand types for amdgpu.wmma: 16-element vectors of f16, bf16,
+// or 8-bit integers (signless, signed or unsigned).
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+// Accumulator/result (C/D) types: 4- or 8-element f32/i32 vectors, or 8- or
+// 16-element f16/bf16 vectors.
+// NOTE(review): these constraints also admit pairings such as f16 sources with
+// a bf16 accumulator; WMMAOp::verify does not reject them and the ROCDL
+// lowering has no intrinsic for them.
+def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
+ VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
def AMDGPU_MFMAOp :
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
@@ -438,4 +442,41 @@ def AMDGPU_MFMAOp :
let hasVerifier = 1;
}
+// amdgpu.wmma: destD = sourceA * sourceB + destC. The AMDGPUToROCDL lowering
+// accepts it only for gfx11 and maps it to one of the rocdl.wmma.* ops.
+// How that lowering uses the attributes:
+//  - subwordOffset is forwarded as a trailing i1 only for f16/bf16 destinations.
+//  - clamp is forwarded only for i32 destinations; it has no effect otherwise.
+//  - unsignedA/unsignedB are consulted only for 8-bit integer sources.
+// NOTE(review): the description below documents only the 16-element f16/bf16
+// destination layout; the 8-element f16/bf16 destinations allowed by
+// WMMAOutTypes are not described.
+def AMDGPU_WMMAOp :
+ AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
+ AllTypesMatch<["sourceA", "sourceB"]>,
+ Pure]>,
+ Arguments<(ins
+ WMMAInTypes:$sourceA,
+ WMMAInTypes:$sourceB,
+ WMMAOutTypes:$destC,
+ DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+ UnitAttr:$unsignedA,
+ UnitAttr:$unsignedB,
+ UnitAttr:$clamp)>,
+ Results<(outs WMMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for RDNA3 wmma instructions";
+ let description = [{
+ The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
+ for various `wmma` instructions in the RDNA3 architecture, which perform
+ a 16x16 matrix multiplication for different data types.
+
+ When emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector
+ containing only 8 valid values:
+ - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
+ - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+
+ `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
+
+ The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+ in case of overflow.
+ }];
+ let assemblyFormat = [{
+ $sourceA `*` $sourceB `+` $destC
+ attr-dict
+ `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // AMDGPU
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3187ce1615ff93f..1de2031a64c0307 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -124,6 +124,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
let assemblyFormat = "attr-dict";
}
+
//===---------------------------------------------------------------------===//
// Xdlops intrinsics
@@ -182,6 +183,26 @@ def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.f
def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
+//===---------------------------------------------------------------------===//
+// WMMA intrinsics
+// ROCDL_Wmma_IntrOp<"wmma.X"> defines the op rocdl.wmma.X. It wraps the LLVM
+// intrinsic whose enum name is "amdgcn_" followed by the mnemonic with '.'
+// replaced by '_'.
+// The `[0], [], traits, 1` arguments presumably mark result 0 as the
+// overloaded type, with no overloaded operands and a single result — confirm
+// against LLVM_IntrOpBase's parameter order. The rocdl.mlir translation tests
+// show result-type suffixes such as .v8f32.
+// Operands are an untyped variadic list, so this definition checks neither
+// their count nor their types.
+class ROCDL_Wmma_IntrOp<string mnemonic, list<Trait> traits = []> :
+ LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+ "amdgcn_" # !subst(".","_", mnemonic),
+ [0], [], traits, 1>,
+ Arguments<(ins Variadic<LLVM_Type>:$args)> {
+ let assemblyFormat =
+ "$args attr-dict `:` functional-type($args, $res)";
+}
+
+// Available on RDNA3
+// NOTE(review): wmma.i32.16x16x16.iu4 is not reachable from amdgpu.wmma, whose
+// source types (WMMAInTypes) include no 4-bit integer element type.
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16">;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16">;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16">;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16">;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8">;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4">;
+
+
//===---------------------------------------------------------------------===//
// Vector buffer load/store intrinsics
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0fc1aaef0c7db41..1de583c17809329 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -32,6 +32,12 @@ static Value createI32Constant(ConversionPatternRewriter &rewriter,
return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
}
+/// Materialize `value` as an i1 LLVM::ConstantOp at `loc`. The WMMA lowering
+/// uses it for the sign, subword-offset (op_sel) and clamp flag operands.
+static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
+ bool value) {
+ Type llvmI1 = rewriter.getI1Type();
+ return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI1, value);
+}
+
namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
@@ -334,6 +340,64 @@ static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
return input;
}
+/// Append the operand(s) for one WMMA source (A or B) to `operands`.
+/// - Sources whose element type is not 8-bit integer (f16/bf16): `llvmInput`
+///   is appended unchanged.
+/// - 8-bit integer sources: an i1 signedness flag (1 for signed, 0 for
+///   unsigned) is appended, followed by `llvmInput` bitcast from N x i8 to
+///   (N / 4) x i32 (16 x i8 -> 4 x i32). This matches the (i1, vector<4xi32>)
+///   operand pairs of the iu8 intrinsic.
+/// `isUnsigned` is the op's unsignedA/unsignedB flag.
+/// NOTE(review): bf16 sources are passed through as bf16 vectors, while the
+/// rocdl.wmma.*.bf16 translation tests use i16 vectors for bf16 operands.
+/// Confirm the intrinsic accepts bf16 here, or bitcast to i16.
+static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
+ Location loc, TypeConverter *typeConverter,
+ bool isUnsigned, Value llvmInput,
+ SmallVector<Value, 4> &operands) {
+ Type inputType = llvmInput.getType();
+ // NOTE(review): the dyn_cast result is dereferenced unchecked. The ODS
+ // constraints guarantee a vector, so cast<> would state that invariant.
+ auto vectorType = inputType.dyn_cast<VectorType>();
+ Type elemType = vectorType.getElementType();
+
+ if (!elemType.isInteger(8)) {
+ operands.push_back(llvmInput);
+ return;
+ }
+
+ // Repack N x i8 into (N * 8 / 32) x i32, e.g. 16 x i8 -> 4 x i32.
+ int64_t numBytes = vectorType.getNumElements();
+ Type i32 = rewriter.getI32Type();
+ VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
+ auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+
+ Value result = rewriter.createOrFold<LLVM::BitcastOp>(
+ loc, llvmVectorType32bits, llvmInput);
+
+ // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+ // NOTE(review): `llvmInput` is the adaptor (type-converted) value. If the LLVM
+ // type converter maps si8/ui8 to signless i8 (it appears to keep only the
+ // width — confirm), neither branch below can fire. A ui8 source without
+ // unsignedA/unsignedB is then emitted as signed. Inspecting the original op
+ // operand's type would avoid this.
+ bool localIsUnsigned = isUnsigned;
+ if (elemType.isUnsignedInteger(8)) {
+ localIsUnsigned = true;
+ } else if (elemType.isSignedInteger(8)) {
+ localIsUnsigned = false;
+ }
+ // The intrinsic's flag means "signed", hence the negation.
+ Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+ operands.push_back(sign);
+ operands.push_back(result);
+}
+
+/// Append the accumulator operand `output` to `operands`. Then, depending on
+/// its element type, append the intrinsic's trailing i1 operand:
+///  - f16/bf16 destinations: `subwordOffset`. The f16 -> f16 and
+///    bf16 -> bf16 intrinsics use the same number of VGPRs as the f32
+///    variants, so this flag selects which 16-bit half of each VGPR receives
+///    the result.
+///    Per the amdgpu.wmma description, 0 writes indices 0, 2, ..., 14 and 1
+///    writes indices 1, 3, ..., 15 of a 16-element destination. With two
+///    16-bit values packed little-endian per VGPR, that means 0 = lower half
+///    and 1 = upper half. This assumes the packing, and that the intrinsic's
+///    i1 operand is its `%high` op_sel bit — confirm in IntrinsicsAMDGPU.td.
+///  - i32 destinations: `clamp`.
+///  - f32 destinations: nothing extra is appended, so `clamp` has no effect.
+/// `subwordOffset` is implicitly converted to bool (nonzero -> true); the op
+/// constrains it to 0 or 1.
+/// NOTE(review): `typeConverter` is unused, `inputType` actually holds the
+/// output's type, and the dyn_cast result is dereferenced unchecked.
+static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
+ Location loc, TypeConverter *typeConverter,
+ Value output, int32_t subwordOffset,
+ bool clamp, SmallVector<Value, 4> &operands) {
+ Type inputType = output.getType();
+ auto vectorType = inputType.dyn_cast<VectorType>();
+ Type elemType = vectorType.getElementType();
+ operands.push_back(output);
+ if (elemType.isF16() || elemType.isBF16()) {
+ operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+ } else if (elemType.isInteger(32)) {
+ operands.push_back(createI1Constant(rewriter, loc, clamp));
+ }
+}
+
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -471,6 +535,31 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return std::nullopt;
}
+/// Return the name of the `rocdl` WMMA intrinsic op that matches the element
+/// types of `wmma`'s sourceA and destC, or std::nullopt if there is none
+/// (e.g. f16 sources with a bf16 accumulator).
+/// Supported pairings: f16 -> f32, bf16 -> f32, f16 -> f16, bf16 -> bf16, and
+/// 8-bit integer -> i32. isInteger(8) does not distinguish signedness. The iu4
+/// intrinsic is never selected.
+/// NOTE(review): `chipset` is currently unused, so this function performs no
+/// architecture check. The gfx11 restriction is enforced by the caller,
+/// WMMAOpLowering::matchAndRewrite.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+ Chipset chipset) {
+
+ auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
+ auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
+ auto elemSourceType = sourceVectorType.getElementType();
+ auto elemDestType = destVectorType.getElementType();
+
+ if (elemSourceType.isF16() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+ } else if (elemSourceType.isBF16() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+ } else if (elemSourceType.isF16() && elemDestType.isF16()) {
+ return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+ } else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
+ return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+ } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
+ return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ }
+ return std::nullopt;
+}
+
namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
@@ -510,6 +599,45 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}
};
+/// Lower amdgpu.wmma to the matching rocdl.wmma.* intrinsic op. Lowering fails
+/// with an op error unless the chipset's major version is 11 and
+/// wmmaOpToIntrinsic finds an intrinsic for the operand element types.
+/// Emitted operand order: [signA,] A, [signB,] B, C, [subwordOffset | clamp].
+struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
+ WMMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
+
+ // Target chipset this pattern was constructed with.
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type outType = typeConverter->convertType(op.getDestD().getType());
+
+ if (chipset.majorVersion != 11)
+ return op->emitOpError("WMMA only supported on gfx11");
+
+ std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
+
+ // Reached for element-type pairings that have no intrinsic (for example
+ // f16 sources with a bf16 accumulator), not for chipset mismatches.
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError("no intrinsic matching WMMA on the given chipset");
+
+ // Build the intrinsic op generically by name, with a single result of the
+ // converted destination type.
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes(outType);
+
+ SmallVector<Value, 4> operands;
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
+ adaptor.getSourceA(), operands);
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
+ adaptor.getSourceB(), operands);
+ wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
+ op.getSubwordOffset(), op.getClamp(), operands);
+
+ loweredOp.addOperands(operands);
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered->getResults());
+
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
ConvertAMDGPUToROCDLPass() = default;
@@ -549,7 +677,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RawBufferOpLowering<RawBufferAtomicUminOp, ROCDL::RawBufferAtomicUMinOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawBufferAtomicCmpSwap>,
- MFMAOpLowering>(converter, chipset);
+ MFMAOpLowering, WMMAOpLowering>(converter, chipset);
}
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 105535b05de859c..ac34acc8307485c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -205,6 +205,34 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
context);
}
+//===----------------------------------------------------------------------===//
+// WMMAOp
+//===----------------------------------------------------------------------===//
+/// Verify that sourceA and destC agree on float vs. integer element kind.
+/// sourceB and destD need no separate check: the AllTypesMatch traits force
+/// them to match sourceA and destC respectively.
+/// NOTE(review): f16 sources with a bf16 accumulator (or bf16 with f16) pass
+/// this check, but wmmaOpToIntrinsic has no intrinsic for them. Such ops
+/// therefore only fail later, during AMDGPUToROCDL lowering.
+LogicalResult WMMAOp::verify() {
+ Type sourceAType = getSourceA().getType();
+ Type destType = getDestC().getType();
+
+ // ODS constraints (WMMAInTypes/WMMAOutTypes) guarantee vector types; the
+ // dyn_cast results are dereferenced without a null check.
+ VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
+ VectorType destVectorType = destType.dyn_cast<VectorType>();
+
+ Type sourceAElemType = sourceVectorAType.getElementType();
+ Type destElemType = destVectorType.getElementType();
+
+ bool isDestFloat =
+ (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
+ bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+
+ if (isDestFloat && !isSrcFloat) {
+ return emitOpError("Expected float sources with float destination");
+ }
+
+ if (!isDestFloat && isSrcFloat) {
+ return emitOpError("Expected int sources with int destination");
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
new file mode 100644
index 000000000000000..407519605d42405
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// Lowers each supported amdgpu.wmma element-type pairing for gfx1100 and
+// verifies which rocdl.wmma.* op and operand types are produced.
+// NOTE(review):
+//  - Despite its name, this function tests wmma lowering, not mfma.
+//  - The trailing i1 operands are matched with a wildcard, so the emitted
+//    subwordOffset/clamp/sign values are not pinned.
+//  - The ui8 case sets unsignedA/unsignedB explicitly, so it does not
+//    exercise the implicit ui8 handling in wmmaPushInputOperand.
+func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
+ %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
+ %arg6 : vector<16xi8>, %arg7 : vector<4xi32>, %arg8 : vector<8xi32>,
+ %arg9 : vector<16xui8>){
+ // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+ // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+ amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+ // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+ // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<4xf32>) -> vector<4xf32>
+ amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+ amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+ amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+ // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<16xbf16>, i1) -> vector<16xbf16>
+ amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+ // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>, i1) -> vector<8xbf16>
+ amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+ amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<8xi32>
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 82d7af2c6dfa4f9..142224e59a95d7a 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -103,3 +103,11 @@ func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
func.return %d : vector<32xf32>
}
+
+// -----
+
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // f16 sources with an i32 accumulator must be rejected by WMMAOp::verify.
+ // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}}
+ %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index a612fba56a7634c..4088c6750c91b8d 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -94,3 +94,10 @@ func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
%0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>
func.return %0 : vector<32xf32>
}
+
+// Minimal valid amdgpu.wmma: 16 x f16 sources, an 8 x f16 accumulator and
+// default attributes (subwordOffset = 0, no unit flags).
+// CHECK-LABEL: func @wmma
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+ // CHECK: amdgpu.wmma
+ %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+ func.return %0 : vector<8xf16>
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e3f942e8bd785c7..03de137f2252233 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -215,6 +215,66 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
llvm.return %r0 : vector<32 x f32>
}
+// Translation of the rocdl.wmma.* ops to llvm.amdgcn.wmma.* intrinsic calls.
+// Each call's intrinsic name carries a result-type suffix (.v8f32, .v4f32,
+// ...). The bf16 variants take and return i16 vectors here.
+// Every i1 flag operand (op_sel / sign / clamp) is passed %zero, so only the
+// false value is exercised. The "{0,1}" notes below list the legal values,
+// not the values tested.
+llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
+ %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
+ %zero = llvm.mlir.constant(false) : i1
+
+ // ---- Wave32 -----
+
+ // f16 -> f32
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}})
+ %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+
+ // bf16 -> f32
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}})
+ %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
+
+ // f16 -> f16 (OPSEL = {0,1})
+ // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}})
+ %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+
+ // bf16 -> bf16 (OPSEL = {0,1})
+ // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}})
+ %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+
+ // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+ %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+ // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+ %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+ // ---- Wave64 -----
+
+ // f16 -> f32
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}})
+ %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+
+ // bf16 -> f32
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}})
+ %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
+
+ // f16 -> f16 (OPSEL = {0,1})
+ // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}})
+ %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+
+ // bf16 -> bf16 (OPSEL = {0,1})
+ // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}})
+ %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+
+ // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+ %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+ // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+ %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+ llvm.return %r0 : vector<8xf32>
+}
+
+
llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32,
%offset : i32, %vdata1 : vector<1xf32>,
%vdata2 : vector<2xf32>, %vdata4 : vector<4xf32>) {
More information about the Mlir-commits
mailing list