[Mlir-commits] [mlir] c55b41d - [mlir][AMDGPU] Define amdgpu.mfma operator

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Aug 31 14:06:22 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-08-31T21:06:12Z
New Revision: c55b41d5199d2394dd6cdb8f52180d8b81d809d4

URL: https://github.com/llvm/llvm-project/commit/c55b41d5199d2394dd6cdb8f52180d8b81d809d4
DIFF: https://github.com/llvm/llvm-project/commit/c55b41d5199d2394dd6cdb8f52180d8b81d809d4.diff

LOG: [mlir][AMDGPU] Define amdgpu.mfma operator

The amdgpu.mfma operator is a wrapper around the Matrix Fused Multiply-Add
(MFMA) instructions available on some AMD GPUs (the CDNA-based MI-* cards).

This interface selects the operation to perform by specifying the
dimensions of the multiplication along with any additional attributes
(such as whether to use reduced-precision floating-point math) that are
needed to pick the relevant mfma instruction and set its parameters.
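
For illustration, here is one of the cases exercised by the new tests
(operand names shortened), together with the ROCDL intrinsic it lowers to
under -convert-amdgpu-to-rocdl=chipset=gfx940:

    // A 2-block, 32x32, k=1 single-precision MFMA.
    %d = amdgpu.mfma %a * %a + %acc {
      m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
      abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, vector<32xf32>
    // Lowers to:
    //   rocdl.mfma.f32.32x32x1f32
    //     : (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>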

Reviewed By: ThomasRaoux, nirvedhmeshram

Differential Revision: https://reviews.llvm.org/D132956

Added: 
    mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
    mlir/test/Dialect/AMDGPU/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
    mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
    mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
    mlir/test/Dialect/AMDGPU/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
index fe97f8e9e9368..040af3b9b27c3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -10,6 +10,7 @@
 #define AMDGPU
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
 def AMDGPU_Dialect : Dialect {
@@ -23,6 +24,7 @@ def AMDGPU_Dialect : Dialect {
   }];
 
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+  let useDefaultAttributePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -182,5 +184,92 @@ def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
   let assemblyFormat = "attr-dict";
 }
 
+def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
+    "The possible permutations of the lanes storing B available in an MFMA",
+    [
+      I32EnumAttrCase<"none",            0>,
+      I32EnumAttrCase<"bcast_first_32",  1>,
+      I32EnumAttrCase<"bcast_second_32", 2>,
+      I32EnumAttrCase<"rotate_16_right", 3>,
+      I32EnumAttrCase<"bcast_first_16",  4>,
+      I32EnumAttrCase<"bcast_second_16", 5>,
+      I32EnumAttrCase<"bcast_third_16",  6>,
+      I32EnumAttrCase<"bcast_fourth_16", 7>
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_MFMAPermBAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAPermB,
+  "mfma_perm_b">;
+
+// mfma
+def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
+                             VectorOfLengthAndType<[2], [F32]>,
+                             VectorOfLengthAndType<[4], [F16]>,
+                             VectorOfLengthAndType<[2, 4], [BF16]>,
+                             VectorOfLengthAndType<[4, 8], [I8]>]>;
+def MFMAOutTypes : AnyTypeOf<[F64,
+                              VectorOfLengthAndType<[4, 16, 32], [F32]>,
+                              VectorOfLengthAndType<[4, 16, 32], [I32]>,
+                              VectorOfLengthAndType<[4], [F64]>]>;
+
+def AMDGPU_MFMAOp :
+    AMDGPU_Op<"mfma", [AllTypesMatch<["sourceA", "sourceB"]>,
+                        AllTypesMatch<["destC", "destD"]>,
+                        NoSideEffect]>,
+    Arguments<(ins
+                   I32Attr:$m,
+                   I32Attr:$n,
+                   I32Attr:$k,
+                   I32Attr:$blocks,
+                   MFMAInTypes:$sourceA,
+                   MFMAInTypes:$sourceB,
+                   MFMAOutTypes:$destC,
+                   DefaultValuedAttr<I32Attr, "0">:$cbsz,
+                   DefaultValuedAttr<I32Attr, "0">:$abid,
+                   DefaultValuedAttr<AMDGPU_MFMAPermBAttr,
+                    "::mlir::amdgpu::MFMAPermB::none">:$blgp,
+                   UnitAttr:$reducePrecision,
+                   UnitAttr:$negateA,
+                   UnitAttr:$negateB,
+                   UnitAttr:$negateC)>,
+    Results<(outs MFMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let description = [{
+    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+    for various `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The wrapper will select an appropriate `mfma` instruction, if one is available,
+    based on the provided `m`, `k`, `n`, and `blocks` attributes, along with the
+    types of the source and destination arguments.
+
+    For information on the layouts of the input and output matrices (which are stored
+    in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation.
+
+    The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave
+    are permuted when matrix data is being loaded: `blgp` can be any of a number
+    of fixed permutations, `cbsz` specifies the log_2 of the number of chunks the
+    lanes holding `sourceA` are split into, and `abid` selects one of those chunks.
+
+    Note that this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+    intrinsics that take an integer type of width `4K`. For example,
+    one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
+    logically takes 4 i8s but whose intrinsics are specified to take an i32.
+    In these cases, the bytes in the vector will be concatenated in little-endian
+    order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8], and so on).
+
+    The `negateA`, `negateB`, and `negateC` flags are only supported for
+    double-precision operations on gfx940+.
+  }];
+  let assemblyFormat = [{
+    $sourceA `*` $sourceB `+` $destC
+    attr-dict
+    `blgp` `=` $blgp
+    `:` type($sourceA) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
 
 #endif // AMDGPU

diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
index 92193736d3297..7a7b86695ccf2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
@@ -21,6 +21,11 @@
 
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h.inc"
 
+#include "mlir/Dialect/AMDGPU/AMDGPUEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMDGPU/AMDGPU.h.inc"
 

diff --git a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
index a2d56b067d721..ed074c205a551 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
@@ -2,3 +2,11 @@ add_mlir_dialect(AMDGPU amdgpu)
 add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc)
 
 set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
+mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRAMDGPUEnumsGen)
+
+set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
+mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu)
+mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu)
+add_public_tablegen_target(MLIRAMDGPUAttributesIncGen)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 81b8f6e7a03a5..7a3bf22327834 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -11,9 +11,12 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Pass/Pass.h"
 
+#include "llvm/ADT/STLExtras.h"
+
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
 #include "mlir/Conversion/Passes.h.inc"
@@ -25,7 +28,7 @@ using namespace mlir::amdgpu;
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
   Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+  return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI32, value);
 }
 
 namespace {
@@ -123,7 +126,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     MemRefDescriptor memrefDescriptor(memref);
     Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
     Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));
-    Value c32I64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 32);
+    Value c32I64 = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmI64, rewriter.getI64IntegerAttr(32));
 
     Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);
 
@@ -273,6 +277,173 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
     return success();
   }
 };
+} // namespace
+
+/// If `input` is a vector of bytes, concatenate those bytes in little-endian
+/// order to form a single integer of size 8 * [vector length]. This works
+/// around a wart in the AMDGPU intrinsics where operations that logically take
+/// vectors of bytes instead take integers. Since we do not want to expose this
+/// implementation detail to MLIR, we correct for it here.
+static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
+                                Location loc, Value input) {
+  Type inputType = input.getType();
+  if (auto vectorType = inputType.dyn_cast<VectorType>()) {
+    if (!vectorType.getElementType().isInteger(8))
+      return input;
+    int64_t numBytes = vectorType.getNumElements();
+    Type destType = rewriter.getIntegerType(numBytes * 8);
+    Value result = rewriter.createOrFold<LLVM::ConstantOp>(
+        loc, destType, rewriter.getIntegerAttr(destType, 0));
+    for (int64_t i = 0; i < numBytes; ++i) {
+      Value idxConst = createI32Constant(rewriter, loc, i);
+      Value element =
+          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
+      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
+      Value shiftConst = rewriter.createOrFold<LLVM::ConstantOp>(
+          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
+      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
+      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    }
+    return result;
+  }
+  return input;
+}
+
+/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture being compiled for.
+static Optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset) {
+  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
+           b = mfma.getBlocks();
+  Type sourceElem = mfma.getSourceA().getType();
+  if (auto sourceType = sourceElem.dyn_cast<VectorType>())
+    sourceElem = sourceType.getElementType();
+  Type destElem = mfma.getDestC().getType();
+  if (auto destType = destElem.dyn_cast<VectorType>())
+    destElem = destType.getElementType();
+
+  if (sourceElem.isF32() && destElem.isF32()) {
+    if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
+      if (m == 32 && n == 32 && k == 4 && b == 1)
+        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
+      if (m == 16 && n == 16 && k == 8 && b == 1)
+        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
+    }
+    if (m == 32 && n == 32 && k == 1 && b == 2)
+      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
+    if (m == 16 && n == 16 && k == 1 && b == 4)
+      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
+    if (m == 4 && n == 4 && k == 1 && b == 16)
+      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
+    if (m == 32 && n == 32 && k == 2 && b == 1)
+      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 1)
+      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
+  }
+
+  if (sourceElem.isF16() && destElem.isF32()) {
+    if (m == 32 && n == 32 && k == 4 && b == 2)
+      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 4)
+      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 16)
+      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
+    if (m == 32 && n == 32 && k == 8 && b == 1)
+      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
+    if (m == 16 && n == 16 && k == 16 && b == 1)
+      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
+  }
+
+  if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
+    if (m == 32 && n == 32 && k == 4 && b == 2)
+      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 4)
+      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 16)
+      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
+    if (m == 32 && n == 32 && k == 8 && b == 1)
+      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
+    if (m == 16 && n == 16 && k == 16 && b == 1)
+      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
+  }
+
+  if (sourceElem.isBF16() && destElem.isF32()) {
+    if (m == 32 && n == 32 && k == 2 && b == 2)
+      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
+    if (m == 16 && n == 16 && k == 2 && b == 4)
+      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
+    if (m == 4 && n == 4 && k == 2 && b == 16)
+      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
+    if (m == 32 && n == 32 && k == 4 && b == 1)
+      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
+    if (m == 16 && n == 16 && k == 8 && b == 1)
+      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
+  }
+
+  if (sourceElem.isa<IntegerType>() && destElem.isInteger(32)) {
+    if (m == 32 && n == 32 && k == 4 && b == 2)
+      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 4)
+      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 16)
+      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
+    if (m == 32 && n == 32 && k == 8 && b == 1)
+      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
+    if (m == 16 && n == 16 && k == 16 && b == 1)
+      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
+    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
+      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
+    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
+      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
+  }
+
+  if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
+    if (m == 16 && n == 16 && k == 4 && b == 1)
+      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 4)
+      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
+  }
+  return None;
+}
+
+namespace {
+struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
+  MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.getDestD().getType());
+
+    if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
+      return op->emitOpError("MFMA only supported on gfx908+");
+    uint32_t blgpField = static_cast<uint32_t>(op.getBlgp());
+    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
+      if (chipset.minorVersion < 0x40)
+        return op.emitOpError("negation unsupported on older than gfx840");
+      blgpField |=
+          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
+    }
+    Optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
+    OperationState loweredOp(loc, *maybeIntrinsic);
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands(
+        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
+         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
+         createI32Constant(rewriter, loc, op.getAbid()),
+         createI32Constant(rewriter, loc, blgpField)});
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+    return success();
+  }
+};
 
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
@@ -290,6 +461,7 @@ struct ConvertAMDGPUToROCDLPass
     LLVMTypeConverter converter(ctx);
     populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
     LLVMConversionTarget target(getContext());
+    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
     target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
@@ -306,8 +478,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
   patterns.add<
       RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
       RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
-      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
-      converter, chipset);
+      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>,
+      MFMAOpLowering>(converter, chipset);
 }
 
 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 68639da21428b..05b0e621cf0f9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -11,19 +11,29 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
+using namespace mlir::amdgpu;
 
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.cpp.inc"
 
-void amdgpu::AMDGPUDialect::initialize() {
+void AMDGPUDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
+      >();
 }
 
 //===----------------------------------------------------------------------===//
@@ -44,17 +54,78 @@ static LogicalResult verifyRawBufferOp(T &op) {
   return success();
 }
 
-LogicalResult amdgpu::RawBufferLoadOp::verify() {
-  return verifyRawBufferOp(*this);
-}
+LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
+
+LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
 
-LogicalResult amdgpu::RawBufferStoreOp::verify() {
+LogicalResult RawBufferAtomicFaddOp::verify() {
   return verifyRawBufferOp(*this);
 }
 
-LogicalResult amdgpu::RawBufferAtomicFaddOp::verify() {
-  return verifyRawBufferOp(*this);
+//===----------------------------------------------------------------------===//
+// MFMAOp
+//===----------------------------------------------------------------------===//
+LogicalResult MFMAOp::verify() {
+  constexpr uint32_t waveSize = 64;
+  Builder b(getContext());
+
+  Type sourceType = getSourceA().getType();
+  Type destType = getDestC().getType();
+
+  Type sourceElem = sourceType, destElem = destType;
+  uint32_t sourceLen = 1, destLen = 1;
+  if (auto sourceVector = sourceType.dyn_cast<VectorType>()) {
+    sourceLen = sourceVector.getNumElements();
+    sourceElem = sourceVector.getElementType();
+  }
+  if (auto destVector = destType.dyn_cast<VectorType>()) {
+    destLen = destVector.getNumElements();
+    destElem = destVector.getElementType();
+  }
+
+  // Normalize the wider integer types some intrinsics expect to their i8
+  // equivalents.
+  if (sourceElem.isInteger(32)) {
+    sourceLen *= 4;
+    sourceElem = b.getI8Type();
+  }
+  if (sourceElem.isInteger(64)) {
+    sourceLen *= 8;
+    sourceElem = b.getI8Type();
+  }
+
+  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
+  if (sourceLen != numSourceElems)
+    return emitOpError("expected " + Twine(numSourceElems) +
+                       " source values for this operation but got " +
+                       Twine(sourceLen));
+
+  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
+  if (destLen != numDestElems)
+    return emitOpError("expected " + Twine(numDestElems) +
+                       " result values for this operation but got " +
+                       Twine(destLen));
+
+  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
+    return emitOpError(
+        "double-precision ops do not support permuting lanes of B");
+  if (destElem.isF64() && getCbsz() != 0)
+    return emitOpError(
+        "double-precision ops do not support permuting lanes of A");
+  if (getAbid() >= (1 << getCbsz()))
+    return emitOpError(
+        "block ID for permuting A (abid) must be below 2 ** cbsz");
+
+  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
+    return emitOpError(
+        "negation flags only available for double-precision operations");
+
+  return success();
 }
 
+#include "mlir/Dialect/AMDGPU/AMDGPUEnums.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"

diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 60834c45babf1..1b80265baa90b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -5,6 +5,8 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU
 
   DEPENDS
+  MLIRAMDGPUEnumsGen
+  MLIRAMDGPUAttributesIncGen
   MLIRAMDGPUIncGen
 
   LINK_LIBS PUBLIC

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
new file mode 100644
index 0000000000000..9117cd7b2126c
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
+                    %arg2 : vector<16xf32>, %arg3 : vector<4xf32>,
+                    %arg4 : vector<4xf16>, %arg5 : vector<4xi8>,
+                    %arg6 : vector<32xi32>, %arg7 : vector<16xi32>,
+                    %arg8 : vector<4xi32>, %arg9 : vector<2xbf16>,
+                    %arg10 : vector<4xbf16>, %arg11 : f64,
+                    %arg12 : vector<4xf64>, %arg13 : vector<8xi8>,
+                    %arg14 : vector<2xf32>) {
+  // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : f32, vector<32xf32>
+  // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : f32, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : f32, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : f32, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : f32, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xf16>, vector<32xf32>
+  // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xi8>, vector<32xi32>
+  // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xi8>, vector<16xi32>
+  // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xi8>, vector<4xi32>
+  // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xi8>, vector<16xi32>
+  // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xi8>, vector<4xi32>
+  // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<2xbf16>, vector<32xf32>
+  // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<2xbf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<2xbf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<2xbf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<2xbf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xbf16>, vector<32xf32>
+  // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xbf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xbf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xbf16>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xbf16>, vector<4xf32>
+  // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
+  amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : f64, vector<4xf64>
+  // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
+  amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 }  blgp = none : f64, f64
+  // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xi8>, vector<4xi32>
+  // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xi8>, vector<16xi32>
+  // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision }  blgp = none : vector<2xf32>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision }  blgp = none : vector<2xf32>, vector<16xf32>
+  func.return
+}

diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
new file mode 100644
index 0000000000000..9ac8038655dd6
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @bad_source_arguments(%a: vector<2xf32>, %b: vector<2xf32>,
+                                %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error at +1 {{'amdgpu.mfma' op expected 1 source values for this operation but got 2}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<32xf32>
+  func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @bad_source_arguments_i8(%a: vector<8xi8>, %b: vector<8xi8>,
+                                   %c: vector<4xi32>) -> vector<4xi32> {
+  // expected-error at +1 {{'amdgpu.mfma' op expected 4 source values for this operation but got 8}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 4 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<4xi32>
+  func.return %d : vector<4xi32>
+}
+
+// -----
+
+func.func @bad_dest_type(%a: f32, %b: f32, %c: vector<16xf32>) -> vector<16xf32> {
+  // expected-error at +1 {{'amdgpu.mfma' op expected 32 result values for this operation but got 16}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, vector<16xf32>
+  return %d : vector<16xf32>
+}
+
+// -----
+
+func.func @f64_permuting_b(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> {
+  // expected-error at +1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of B}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, vector<4xf64>
+  return %d : vector<4xf64>
+}
+
+// -----
+
+func.func @f64_permuting_a(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> {
+  // expected-error at +1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of A}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
+    abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, vector<4xf64>
+  return %d : vector<4xf64>
+}
+
+// -----
+
+func.func @abid_without_broadcast(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error at +1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, vector<32xf32>
+  func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @abid_too_large(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error at +1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, vector<32xf32>
+  func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error at +1 {{'amdgpu.mfma' op negation flags only available for double-precision operations}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, vector<32xf32>
+  func.return %d : vector<32xf32>
+}

diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 3fff10c666ba2..40a88e25c4dd8 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -66,3 +66,10 @@ func.func @lds_barrier() {
   amdgpu.lds_barrier
   func.return
 }
+
+// CHECK-LABEL: func @mfma
+func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
+  // CHECK: amdgpu.mfma
+  %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, vector<32xf32>
+  func.return %0 : vector<32xf32>
+}


        


More information about the Mlir-commits mailing list