[Mlir-commits] [mlir] c55b41d - [mlir][AMDGPU] Define amdgpu.mfma operator
Krzysztof Drewniak
llvmlistbot at llvm.org
Wed Aug 31 14:06:22 PDT 2022
Author: Krzysztof Drewniak
Date: 2022-08-31T21:06:12Z
New Revision: c55b41d5199d2394dd6cdb8f52180d8b81d809d4
URL: https://github.com/llvm/llvm-project/commit/c55b41d5199d2394dd6cdb8f52180d8b81d809d4
DIFF: https://github.com/llvm/llvm-project/commit/c55b41d5199d2394dd6cdb8f52180d8b81d809d4.diff
LOG: [mlir][AMDGPU] Define amdgpu.mfma operator
The amdgpu.mfma operator is a wrapper around the Matrix Fused Multiply-Add
(MFMA) instructions found on some AMD GPUs (the CDNA-based MI-* cards).
The operation to perform is selected by specifying the dimensions of the
multiplication, along with any additional attributes (such as whether to
use reduced-precision floating-point math) that are needed to pick the
relevant mfma instruction and set its parameters.
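As a small illustration (mirroring the tests added in this commit), with
%a : f32 and %c : vector<32xf32>, a single-precision 32x32x1
multiply-accumulate over two blocks is written

  %d = amdgpu.mfma %a * %a + %c { abid = 0 : i32, cbsz = 0 : i32,
         k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32
       } blgp = none : f32, vector<32xf32>

and lowers to a call to the rocdl.mfma.f32.32x32x1f32 intrinsic.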
Reviewed By: ThomasRaoux, nirvedhmeshram
Differential Revision: https://reviews.llvm.org/D132956
Added:
mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
mlir/test/Dialect/AMDGPU/invalid.mlir
Modified:
mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
mlir/test/Dialect/AMDGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
index fe97f8e9e9368..040af3b9b27c3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -10,6 +10,7 @@
#define AMDGPU
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
def AMDGPU_Dialect : Dialect {
@@ -23,6 +24,7 @@ def AMDGPU_Dialect : Dialect {
}];
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+ let useDefaultAttributePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -182,5 +184,92 @@ def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let assemblyFormat = "attr-dict";
}
+def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
+ "The possible permutations of the lanes storing B available in an MFMA",
+ [
+ I32EnumAttrCase<"none", 0>,
+ I32EnumAttrCase<"bcast_first_32", 1>,
+ I32EnumAttrCase<"bcast_second_32", 2>,
+ I32EnumAttrCase<"rotate_16_right", 3>,
+ I32EnumAttrCase<"bcast_first_16", 4>,
+ I32EnumAttrCase<"bcast_second_16", 5>,
+ I32EnumAttrCase<"bcast_third_16", 6>,
+ I32EnumAttrCase<"bcast_fourth_16", 7>
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_MFMAPermBAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAPermB,
+ "mfma_perm_b">;
+
+// mfma
+def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
+ VectorOfLengthAndType<[2], [F32]>,
+ VectorOfLengthAndType<[4], [F16]>,
+ VectorOfLengthAndType<[2, 4], [BF16]>,
+ VectorOfLengthAndType<[4, 8], [I8]>]>;
+def MFMAOutTypes : AnyTypeOf<[F64,
+ VectorOfLengthAndType<[4, 16, 32], [F32]>,
+ VectorOfLengthAndType<[4, 16, 32], [I32]>,
+ VectorOfLengthAndType<[4], [F64]>]>;
+
+def AMDGPU_MFMAOp :
+ AMDGPU_Op<"mfma", [AllTypesMatch<["sourceA", "sourceB"]>,
+ AllTypesMatch<["destC", "destD"]>,
+ NoSideEffect]>,
+ Arguments<(ins
+ I32Attr:$m,
+ I32Attr:$n,
+ I32Attr:$k,
+ I32Attr:$blocks,
+ MFMAInTypes:$sourceA,
+ MFMAInTypes:$sourceB,
+ MFMAOutTypes:$destC,
+ DefaultValuedAttr<I32Attr, "0">:$cbsz,
+ DefaultValuedAttr<I32Attr, "0">:$abid,
+ DefaultValuedAttr<AMDGPU_MFMAPermBAttr,
+ "::mlir::amdgpu::MFMAPermB::none">:$blgp,
+ UnitAttr:$reducePrecision,
+ UnitAttr:$negateA,
+ UnitAttr:$negateB,
+ UnitAttr:$negateC)>,
+ Results<(outs MFMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for CDNA mfma instructions";
+ let description = [{
+ The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+ for various `mfma` instructions in the CDNA architecture, which perform
+ multiple outer products in order to allow fast matrix multiplication.
+
+ The wrapper will select an appropriate `mfma` instruction, if one is available,
+ based on the provided `m`, `k`, `n`, and `blocks` attributes, along with the
+ types of the source and destination arguments.
+
+ For information on the layouts of the input and output matrices (which are stored
+ in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation.
+
+ The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave
+ are permuted when matrix data is being loaded: `blgp` can be any of a number of
+ fixed permutations, `cbsz` specifies the log_2 of the number of chunks the lanes
+ holding `sourceA` are split into, and `abid` selects one of those chunks.
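+ For example, `cbsz = 1` splits the lanes holding `sourceA` into two chunks,
+ and `abid` (which the verifier requires to be below 2 ** cbsz) selects
+ either chunk 0 or chunk 1.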
+
+ Note that this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+ intrinsics that take an integer type of width `4K`. For example,
+ one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
+ logically takes 4 i8s but whose intrinsic is specified to take an i32.
+ In these cases, the bytes in the vector will be concatenated in little-endian
+ order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8], and so on).
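+ Concretely, a `vector<4xi8>` holding the bytes `[0x01, 0x02, 0x03, 0x04]`
+ is passed to such an intrinsic as the `i32` value `0x04030201`.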
+
+ The `negateA`, `negateB`, and `negateC` flags are only supported for
+ double-precision operations on gfx940+.
+ }];
+ let assemblyFormat = [{
+ $sourceA `*` $sourceB `+` $destC
+ attr-dict
+ `blgp` `=` $blgp
+ `:` type($sourceA) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
#endif // AMDGPU
diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
index 92193736d3297..7a7b86695ccf2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
@@ -21,6 +21,11 @@
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h.inc"
+#include "mlir/Dialect/AMDGPU/AMDGPUEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPU.h.inc"
diff --git a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
index a2d56b067d721..ed074c205a551 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
@@ -2,3 +2,11 @@ add_mlir_dialect(AMDGPU amdgpu)
add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
+mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRAMDGPUEnumsGen)
+
+set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
+mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu)
+mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu)
+add_public_tablegen_target(MLIRAMDGPUAttributesIncGen)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 81b8f6e7a03a5..7a3bf22327834 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -11,9 +11,12 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
+
namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
@@ -25,7 +28,7 @@ using namespace mlir::amdgpu;
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type llvmI32 = rewriter.getI32Type();
- return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+ return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI32, value);
}
namespace {
@@ -123,7 +126,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
MemRefDescriptor memrefDescriptor(memref);
Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));
- Value c32I64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 32);
+ Value c32I64 = rewriter.create<LLVM::ConstantOp>(
+ loc, llvmI64, rewriter.getI64IntegerAttr(32));
Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);
@@ -273,6 +277,173 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
return success();
}
};
+} // namespace
+
+/// If `input` is a vector of bytes, concatenate those bytes in little-endian
+/// order to form a single integer of size 8 * [vector length]. This works
+/// around a wart in the AMDGPU intrinsics where operations that logically take
+/// vectors of bytes instead take integers. Since we do not want to expose this
+/// implementation detail to MLIR, we correct for it here.
+static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
+ Location loc, Value input) {
+ Type inputType = input.getType();
+ if (auto vectorType = inputType.dyn_cast<VectorType>()) {
+ if (!vectorType.getElementType().isInteger(8))
+ return input;
+ int64_t numBytes = vectorType.getNumElements();
+ Type destType = rewriter.getIntegerType(numBytes * 8);
+ Value result = rewriter.createOrFold<LLVM::ConstantOp>(
+ loc, destType, rewriter.getIntegerAttr(destType, 0));
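+    // Accumulate the bytes into the wide integer, least-significant byte
+    // first, matching the little-endian layout documented on the op.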
+ for (int64_t i = 0; i < numBytes; ++i) {
+ Value idxConst = createI32Constant(rewriter, loc, i);
+ Value element =
+ rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
+ Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
+ Value shiftConst = rewriter.createOrFold<LLVM::ConstantOp>(
+ loc, destType, rewriter.getIntegerAttr(destType, i * 8));
+ Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
+ result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+ }
+ return result;
+ }
+ return input;
+}
+
+/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`,
+/// if one exists. This includes checking that the intrinsic is supported
+/// on the architecture being compiled for.
+static Optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset) {
+ uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
+ b = mfma.getBlocks();
+ Type sourceElem = mfma.getSourceA().getType();
+ if (auto sourceType = sourceElem.dyn_cast<VectorType>())
+ sourceElem = sourceType.getElementType();
+ Type destElem = mfma.getDestC().getType();
+ if (auto destType = destElem.dyn_cast<VectorType>())
+ destElem = destType.getElementType();
+
+ if (sourceElem.isF32() && destElem.isF32()) {
+ if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
+ if (m == 32 && n == 32 && k == 4 && b == 1)
+ return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
+ if (m == 16 && n == 16 && k == 8 && b == 1)
+ return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
+ }
+ if (m == 32 && n == 32 && k == 1 && b == 2)
+ return ROCDL::mfma_f32_32x32x1f32::getOperationName();
+ if (m == 16 && n == 16 && k == 1 && b == 4)
+ return ROCDL::mfma_f32_16x16x1f32::getOperationName();
+ if (m == 4 && n == 4 && k == 1 && b == 16)
+ return ROCDL::mfma_f32_4x4x1f32::getOperationName();
+ if (m == 32 && n == 32 && k == 2 && b == 1)
+ return ROCDL::mfma_f32_32x32x2f32::getOperationName();
+ if (m == 16 && n == 16 && k == 4 && b == 1)
+ return ROCDL::mfma_f32_16x16x4f32::getOperationName();
+ }
+
+ if (sourceElem.isF16() && destElem.isF32()) {
+ if (m == 32 && n == 32 && k == 4 && b == 2)
+ return ROCDL::mfma_f32_32x32x4f16::getOperationName();
+ if (m == 16 && n == 16 && k == 4 && b == 4)
+ return ROCDL::mfma_f32_16x16x4f16::getOperationName();
+ if (m == 4 && n == 4 && k == 4 && b == 16)
+ return ROCDL::mfma_f32_4x4x4f16::getOperationName();
+ if (m == 32 && n == 32 && k == 8 && b == 1)
+ return ROCDL::mfma_f32_32x32x8f16::getOperationName();
+ if (m == 16 && n == 16 && k == 16 && b == 1)
+ return ROCDL::mfma_f32_16x16x16f16::getOperationName();
+ }
+
+ if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
+ if (m == 32 && n == 32 && k == 4 && b == 2)
+ return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
+ if (m == 16 && n == 16 && k == 4 && b == 4)
+ return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
+ if (m == 4 && n == 4 && k == 4 && b == 16)
+ return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
+ if (m == 32 && n == 32 && k == 8 && b == 1)
+ return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
+ if (m == 16 && n == 16 && k == 16 && b == 1)
+ return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
+ }
+
+ if (sourceElem.isBF16() && destElem.isF32()) {
+ if (m == 32 && n == 32 && k == 2 && b == 2)
+ return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
+ if (m == 16 && n == 16 && k == 2 && b == 4)
+ return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
+ if (m == 4 && n == 4 && k == 2 && b == 16)
+ return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
+ if (m == 32 && n == 32 && k == 4 && b == 1)
+ return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
+ if (m == 16 && n == 16 && k == 8 && b == 1)
+ return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
+ }
+
+ if (sourceElem.isa<IntegerType>() && destElem.isInteger(32)) {
+ if (m == 32 && n == 32 && k == 4 && b == 2)
+ return ROCDL::mfma_i32_32x32x4i8::getOperationName();
+ if (m == 16 && n == 16 && k == 4 && b == 4)
+ return ROCDL::mfma_i32_16x16x4i8::getOperationName();
+ if (m == 4 && n == 4 && k == 4 && b == 16)
+ return ROCDL::mfma_i32_4x4x4i8::getOperationName();
+ if (m == 32 && n == 32 && k == 8 && b == 1)
+ return ROCDL::mfma_i32_32x32x8i8::getOperationName();
+ if (m == 16 && n == 16 && k == 16 && b == 1)
+ return ROCDL::mfma_i32_16x16x16i8::getOperationName();
+ if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
+ return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
+ if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
+ return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
+ }
+
+ if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
+ if (m == 16 && n == 16 && k == 4 && b == 1)
+ return ROCDL::mfma_f64_16x16x4f64::getOperationName();
+ if (m == 4 && n == 4 && k == 4 && b == 4)
+ return ROCDL::mfma_f64_4x4x4f64::getOperationName();
+ }
+ return None;
+}
+
+namespace {
+struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
+ MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type outType = typeConverter->convertType(op.getDestD().getType());
+
+ if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
+ return op->emitOpError("MFMA only supported on gfx908+");
+ uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
+ if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
+ if (chipset.minorVersion < 0x40)
+ return op.emitOpError("negation unsupported on older than gfx940");
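+      // Pack the negate flags (gfx940+ only) into the low three bits of the
+      // blgp field passed to the intrinsic.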
+ getBlgpField |=
+ op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
+ }
+ Optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError("no intrinsic matching MFMA size on given chipset");
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes(outType);
+ loweredOp.addOperands(
+ {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
+ mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+ adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
+ createI32Constant(rewriter, loc, op.getAbid()),
+ createI32Constant(rewriter, loc, getBlgpField)});
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered->getResults());
+ return success();
+ }
+};
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
@@ -290,6 +461,7 @@ struct ConvertAMDGPUToROCDLPass
LLVMTypeConverter converter(ctx);
populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
LLVMConversionTarget target(getContext());
+ target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
if (failed(applyPartialConversion(getOperation(), target,
@@ -306,8 +478,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
patterns.add<
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
- RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
- converter, chipset);
+ RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>,
+ MFMAOpLowering>(converter, chipset);
}
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 68639da21428b..05b0e621cf0f9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -11,19 +11,29 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
+using namespace mlir::amdgpu;
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.cpp.inc"
-void amdgpu::AMDGPUDialect::initialize() {
+void AMDGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"
>();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
+ >();
}
//===----------------------------------------------------------------------===//
@@ -44,17 +54,78 @@ static LogicalResult verifyRawBufferOp(T &op) {
return success();
}
-LogicalResult amdgpu::RawBufferLoadOp::verify() {
- return verifyRawBufferOp(*this);
-}
+LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
+
+LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
-LogicalResult amdgpu::RawBufferStoreOp::verify() {
+LogicalResult RawBufferAtomicFaddOp::verify() {
return verifyRawBufferOp(*this);
}
-LogicalResult amdgpu::RawBufferAtomicFaddOp::verify() {
- return verifyRawBufferOp(*this);
+//===----------------------------------------------------------------------===//
+// MFMAOp
+//===----------------------------------------------------------------------===//
+LogicalResult MFMAOp::verify() {
+ constexpr uint32_t waveSize = 64;
+ Builder b(getContext());
+
+ Type sourceType = getSourceA().getType();
+ Type destType = getDestC().getType();
+
+ Type sourceElem = sourceType, destElem = destType;
+ uint32_t sourceLen = 1, destLen = 1;
+ if (auto sourceVector = sourceType.dyn_cast<VectorType>()) {
+ sourceLen = sourceVector.getNumElements();
+ sourceElem = sourceVector.getElementType();
+ }
+ if (auto destVector = destType.dyn_cast<VectorType>()) {
+ destLen = destVector.getNumElements();
+ destElem = destVector.getElementType();
+ }
+
+  // Normalize the wider integer types that some intrinsics take (i32, i64)
+  // back to the vectors of i8 they logically represent.
+ if (sourceElem.isInteger(32)) {
+ sourceLen *= 4;
+ sourceElem = b.getI8Type();
+ }
+ if (sourceElem.isInteger(64)) {
+ sourceLen *= 8;
+ sourceElem = b.getI8Type();
+ }
+
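+  // A wave (64 lanes) collectively holds all m * k * blocks source elements,
+  // so compute how many of them each lane is expected to supply.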
+ int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
+ if (sourceLen != numSourceElems)
+ return emitOpError("expected " + Twine(numSourceElems) +
+ " source values for this operation but got " +
+ Twine(sourceLen));
+
+ int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
+ if (destLen != numDestElems)
+ return emitOpError("expected " + Twine(numDestElems) +
+ " result values for this operation but got " +
+ Twine(destLen));
+
+ if (destElem.isF64() && getBlgp() != MFMAPermB::none)
+ return emitOpError(
+ "double-precision ops do not support permuting lanes of B");
+ if (destElem.isF64() && getCbsz() != 0)
+ return emitOpError(
+ "double-precision ops do not support permuting lanes of A");
+ if (getAbid() >= (1 << getCbsz()))
+ return emitOpError(
+ "block ID for permuting A (abid) must be below 2 ** cbsz");
+
+ if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
+ return emitOpError(
+ "negation flags only available for double-precision operations");
+
+ return success();
}
+#include "mlir/Dialect/AMDGPU/AMDGPUEnums.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 60834c45babf1..1b80265baa90b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -5,6 +5,8 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU
DEPENDS
+ MLIRAMDGPUEnumsGen
+ MLIRAMDGPUAttributesIncGen
MLIRAMDGPUIncGen
LINK_LIBS PUBLIC
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
new file mode 100644
index 0000000000000..9117cd7b2126c
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
+ %arg2 : vector<16xf32>, %arg3 : vector<4xf32>,
+ %arg4 : vector<4xf16>, %arg5 : vector<4xi8>,
+ %arg6 : vector<32xi32>, %arg7 : vector<16xi32>,
+ %arg8 : vector<4xi32>, %arg9 : vector<2xbf16>,
+ %arg10 : vector<4xbf16>, %arg11 : f64,
+ %arg12 : vector<4xf64>, %arg13 : vector<8xi8>,
+ %arg14 : vector<2xf32>) {
+ // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : f32, vector<32xf32>
+ // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : f32, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : f32, vector<4xf32>
+ // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : f32, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f32, vector<4xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<32xf32>
+ // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xf16>, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xf16>, vector<4xf32>
+ // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf32>
+ // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
+ amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<32xi32>
+ // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<16xi32>
+ // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xi8>, vector<4xi32>
+ // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<16xi32>
+ // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi32>
+ // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<32xf32>
+ // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<4xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<4xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<32xf32>
+ // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xf32>
+ // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xf32>
+ // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
+ amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, vector<4xf64>
+ // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
+ amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64
+ // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<4xi32>
+ // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<16xi32>
+ // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<4xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<16xf32>
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
new file mode 100644
index 0000000000000..9ac8038655dd6
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @bad_source_arguments(%a: vector<2xf32>, %b: vector<2xf32>,
+ %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error at +1 {{'amdgpu.mfma' op expected 1 source values for this operation but got 2}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+ abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @bad_source_arguments_i8(%a: vector<8xi8>, %b: vector<8xi8>,
+ %c: vector<4xi32>) -> vector<4xi32> {
+ // expected-error at +1 {{'amdgpu.mfma' op expected 4 source values for this operation but got 8}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 32 : i32, n = 32 : i32, k = 4 : i32, blocks = 2 : i32,
+ abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<4xi32>
+ func.return %d : vector<4xi32>
+}
+
+// -----
+
+func.func @bad_dest_type(%a: f32, %b: f32, %c: vector<16xf32>) -> vector<16xf32> {
+ // expected-error at +1 {{'amdgpu.mfma' op expected 32 result values for this operation but got 16}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+ abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, vector<16xf32>
+ return %d : vector<16xf32>
+}
+
+// -----
+
+func.func @f64_permuting_b(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> {
+ // expected-error at +1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of B}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
+ abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, vector<4xf64>
+ return %d : vector<4xf64>
+}
+
+// -----
+
+func.func @f64_permuting_a(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> {
+ // expected-error at +1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of A}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
+ abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, vector<4xf64>
+ return %d : vector<4xf64>
+}
+
+// -----
+
+func.func @abid_without_broadcast(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error at +1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+ abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @abid_too_large(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error at +1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+ abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error at +1 {{'amdgpu.mfma' op negation flags only available for double-precision operations}}
+ %d = amdgpu.mfma %a * %b + %c {
+ m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+ abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 3fff10c666ba2..40a88e25c4dd8 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -66,3 +66,10 @@ func.func @lds_barrier() {
amdgpu.lds_barrier
func.return
}
+
+// CHECK-LABEL: func @mfma
+func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
+ // CHECK: amdgpu.mfma
+ %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, vector<32xf32>
+ func.return %0 : vector<32xf32>
+}