[Mlir-commits] [mlir] 22f0c7a - [mlir][AMDGPU] 8-bit float usage in the AMDGPU dialect

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Feb 15 08:46:13 PST 2023


Author: Krzysztof Drewniak
Date: 2023-02-15T16:46:08Z
New Revision: 22f0c7a45149a55e643665c4cb69faa960e3c565

URL: https://github.com/llvm/llvm-project/commit/22f0c7a45149a55e643665c4cb69faa960e3c565
DIFF: https://github.com/llvm/llvm-project/commit/22f0c7a45149a55e643665c4cb69faa960e3c565.diff

LOG: [mlir][AMDGPU] 8-bit float usage in the AMDGPU dialect

Upcoming AMD hardware (gfx940) will include instructions that accept
8-bit floats. Specifically, there are MFMA instructions that accept
8-bit floats, with the two source operands using either the same f8
format or mixed f8 formats. This patch adds MLIR wrappers for these
intrinsics and explicitly adds support for 8-bit floats in the
gpu-to-rocdl conversion by way of amdgpu-to-rocdl.

Since LLVM IR has no f8 types, both f8 types used on AMD hardware
(f8E5M2FNUZ and f8E4M3FNUZ) are converted to i8 when lowering to LLVM
for AMD GPUs.

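This patch also relaxes the restriction that both source operands to
an amdgpu.mfma instruction have exactly the same type, as this is not
necessarily required for the bf8 (f8E5M2FNUZ) and fp8 (f8E4M3FNUZ)
instructions: f8 operands may now mix the two formats as long as both
vectors have the same length, while non-f8 operands must still match
exactly. Because the two source types can now differ, the amdgpu.mfma
assembly format lists the type of each source operand separately
(`: typeA, typeB, typeC`). In addition, since the raw_buffer_{load,store}
operations only accept an explicit list of permitted types, we add the
relevant f8 types to that list.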

This patch does not add any implementations of arithmetic operations
for f8 types.

Reviewed By: jakeh-gc

Differential Revision: https://reviews.llvm.org/D143956

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
    mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
    mlir/test/Dialect/AMDGPU/invalid.mlir
    mlir/test/Dialect/AMDGPU/ops.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
index 13c368676cf4a..6a113443eff1b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -47,10 +47,11 @@ def AMDGPU_RawBufferLoadOp :
                    DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
                    OptionalAttr<I32Attr>:$indexOffset,
                    Optional<I32>:$sgprOffset)>,
-    Results<(outs AnyTypeOf<[BF16, F16, F32, I32, I8,
+    Results<(outs AnyTypeOf<[BF16, F16, F32, I32, I8, F8E5M2FNUZ, F8E4M3FNUZ,
                               VectorOfLengthAndType<[2, 4], [F32, I32]>,
                               VectorOfLengthAndType<[2, 4, 8], [F16, BF16]>,
-                              VectorOfLengthAndType<[2, 4, 8, 16], [I8]>]>:$value)> {
+                              VectorOfLengthAndType<[2, 4, 8, 16],
+                                [I8, F8E5M2FNUZ, F8E4M3FNUZ]>]>:$value)> {
 
   let summary = "Raw Buffer load, exposing GCN features";
   let description = [{
@@ -96,10 +97,11 @@ def AMDGPU_RawBufferLoadOp :
 def AMDGPU_RawBufferStoreOp :
     AMDGPU_Op<"raw_buffer_store", [AllElementTypesMatch<["value", "memref"]>,
       AttrSizedOperandSegments]>,
-    Arguments<(ins AnyTypeOf<[BF16, F16, F32, I32, I8,
+    Arguments<(ins AnyTypeOf<[BF16, F16, F32, I32, I8, F8E5M2FNUZ, F8E4M3FNUZ,
                               VectorOfLengthAndType<[2, 4], [F32, I32]>,
                               VectorOfLengthAndType<[2, 4, 8], [F16, BF16]>,
-                              VectorOfLengthAndType<[2, 4, 8, 16], [I8]>]>:$value,
+                              VectorOfLengthAndType<[2, 4, 8, 16],
+                                [I8, F8E5M2FNUZ, F8E4M3FNUZ]>]>:$value,
                    Arg<AnyMemRef, "buffer to store to", [MemWrite]>:$memref,
                    Variadic<I32>:$indices,
                    DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
@@ -215,15 +217,15 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[2], [F32]>,
                              VectorOfLengthAndType<[4], [F16]>,
                              VectorOfLengthAndType<[2, 4], [BF16]>,
-                             VectorOfLengthAndType<[4, 8], [I8]>]>;
+                             VectorOfLengthAndType<[4, 8], [I8]>,
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
 
 def AMDGPU_MFMAOp :
-    AMDGPU_Op<"mfma", [AllTypesMatch<["sourceA", "sourceB"]>,
-                        AllTypesMatch<["destC", "destD"]>,
+    AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
     Arguments<(ins
                    I32Attr:$m,
@@ -274,7 +276,7 @@ def AMDGPU_MFMAOp :
     $sourceA `*` $sourceB `+` $destC
     attr-dict
     `blgp` `=` $blgp
-    `:` type($sourceA) `,` type($destC)
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
   let hasVerifier = 1;
 }

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 862a8e1004556..c3bea1cfb1ba5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -172,6 +172,15 @@ def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8">;
 def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">;
 def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">;
 def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">;
+// fp8, only on gfx940
+def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8">;
+def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8">;
+def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8">;
+def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8">;
+def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8">;
+def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8">;
+def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
+def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
 
 //===---------------------------------------------------------------------===//
 // Vector buffer load/store intrinsics

diff  --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c1f720ff5d00b..e4dcb27290ee1 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -404,6 +404,45 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
     if (m == 4 && n == 4 && k == 4 && b == 4)
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
+
+  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
+      chipset.minorVersion >= 0x40) {
+    // Known to be correct because there are no scalar f8 instructions and
+    // because a length mismatch will have been caught by the verifier.
+    Type sourceBElem =
+        mfma.getSourceB().getType().cast<VectorType>().getElementType();
+    if (m == 16 && n == 16 && k == 32 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
+    }
+    if (m == 32 && n == 32 && k == 16 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
+    }
+  }
+
+  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
+      chipset.minorVersion >= 0x40) {
+    Type sourceBElem =
+        mfma.getSourceB().getType().cast<VectorType>().getElementType();
+    if (m == 16 && n == 16 && k == 32 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
+    }
+    if (m == 32 && n == 32 && k == 16 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
+    }
+  }
+
   return std::nullopt;
 }
 
@@ -475,6 +514,14 @@ struct ConvertAMDGPUToROCDLPass
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
+  // ROCDL supports fp8 types in some contexts, but there is no LLVM-level f8
+  // type. Therefore, for this target, declare f8 to be equal to i8.
+  converter.addConversion([](FloatType type) -> std::optional<Type> {
+    if (type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ())
+      return IntegerType::get(type.getContext(), 8);
+    return std::nullopt;
+  });
+
   patterns.add<LDSBarrierOpLowering>(converter);
   patterns.add<
       RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 57bdfd57582c8..a19edd513bb73 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -189,6 +189,24 @@ LogicalResult MFMAOp::verify() {
     destElem = destVector.getElementType();
   }
 
+  Type sourceBType = getSourceB().getType();
+  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+    int64_t sourceBLen = 1;
+    Type sourceBElem = sourceBType;
+    if (auto sourceBVector = sourceBType.dyn_cast<VectorType>()) {
+      sourceBLen = sourceBVector.getNumElements();
+      sourceBElem = sourceBVector.getElementType();
+    }
+    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+      return emitOpError("expected both source operands to have f8 elements");
+    if (sourceLen != sourceBLen)
+      return emitOpError(
+          "expected both f8 source vectors to have the same length");
+  } else {
+    if (sourceType != sourceBType)
+      return emitOpError(
+          "expected both non-f8 source operand types to match exactly");
+  }
   // Normalize the wider integer types the compiler expects to i8
   if (sourceElem.isInteger(32)) {
     sourceLen *= 4;

diff  --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 3fab11a49e2fa..9f56711e7f460 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -49,6 +49,7 @@ func.func @gpu_gcn_raw_buffer_load_i8(%buf: memref<64xi8>, %idx: i32) -> i8 {
   %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> i8
   func.return %0 : i8
 }
+
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi8
 func.func @gpu_gcn_raw_buffer_load_2xi8(%buf: memref<64xi8>, %idx: i32) -> vector<2xi8> {
   // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
@@ -69,6 +70,29 @@ func.func @gpu_gcn_raw_buffer_load_16xi8(%buf: memref<64xi8>, %idx: i32) -> vect
   func.return %0 : vector<16xi8>
 }
 
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ
+func.func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ(%buf: memref<64xf8E5M2FNUZ>, %idx: i32) -> f8E5M2FNUZ {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: llvm.insertelement{{.*}}%[[numRecords]]
+  // CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i8
+  // CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[loaded]] : i8 to f8E5M2FNUZ
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E5M2FNUZ>, i32 -> f8E5M2FNUZ
+  func.return %0 : f8E5M2FNUZ
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ
+func.func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ(%buf: memref<64xf8E4M3FNUZ>, %idx: i32) -> vector<4xf8E4M3FNUZ> {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: llvm.insertelement{{.*}}%[[numRecords]]
+  // CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: %[[cast:.*]] = llvm.bitcast %[[loaded]] : i32 to vector<4xi8>
+  // CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E4M3FNUZ>, i32 -> vector<4xf8E4M3FNUZ>
+  func.return %0 : vector<4xf8E4M3FNUZ>
+}
+
 // Since the lowering logic is shared with loads, only bitcasts need to be rechecked
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_store_i32
 func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {

diff  --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
index 9117cd7b2126c..55b67ce96d240 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -6,68 +6,86 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
                     %arg8 : vector<4xi32>, %arg9 : vector<2xbf16>,
                     %arg10 : vector<4xbf16>, %arg11 : f64,
                     %arg12 : vector<4xf64>, %arg13 : vector<8xi8>,
-                    %arg14 : vector<2xf32>) {
+                    %arg14 : vector<2xf32>, %arg15 : vector<8xf8E5M2FNUZ>,
+                    %arg16 : vector<8xf8E4M3FNUZ>) {
   // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : f32, vector<32xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : f32, f32, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : f32, vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : f32, f32, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : f32, vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : f32, f32, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : f32, vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : f32, f32, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : f32, vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : f32, f32, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xf16>, vector<32xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xf16>, vector<4xf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xf16>, vector<16xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xf16>, vector<4xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xf16>, vector<16xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xf16>, vector<4xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
   // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xi8>, vector<32xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
   // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xi8>, vector<16xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xi8>, vector<4xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
   // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xi8>, vector<16xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xi8>, vector<4xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
   // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<2xbf16>, vector<32xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<2xbf16>, vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<2xbf16>, vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<2xbf16>, vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<2xbf16>, vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xbf16>, vector<32xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 }  blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xbf16>, vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 }  blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xbf16>, vector<4xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 }  blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xbf16>, vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xbf16>, vector<4xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
-  amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : f64, vector<4xf64>
+  amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : f64, f64, vector<4xf64>
   // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
-  amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 }  blgp = none : f64, f64
+  amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 }  blgp = none : f64, f64, f64
   // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xi8>, vector<4xi32>
+  amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
   // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xi8>, vector<16xi32>
+  amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision }  blgp = none : vector<2xf32>, vector<4xf32>
+  amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision }  blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision }  blgp = none : vector<2xf32>, vector<16xf32>
+  amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision }  blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.fp8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg16 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg15 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg15 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.fp8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg16 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.fp8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg16 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+
   func.return
 }

diff  --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 9ac8038655dd6..82d7af2c6dfa4 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -2,12 +2,34 @@
 
 // -----
 
+func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,
+                            %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error@+1 {{'amdgpu.mfma' op expected both non-f8 source operand types to match exactly}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<4xf16>, vector<32xf32>
+  func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @bad_source_types_f8(%a: vector<8xf8E5M2FNUZ>, %b: vector<8xi8>,
+                               %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error@+1 {{'amdgpu.mfma' op expected both source operands to have f8 elements}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xf8E5M2FNUZ>, vector<8xi8>, vector<32xf32>
+  func.return %d : vector<32xf32>
+}
+
+// -----
+
 func.func @bad_source_arguments(%a: vector<2xf32>, %b: vector<2xf32>,
                                 %c: vector<32xf32>) -> vector<32xf32> {
   // expected-error@+1 {{'amdgpu.mfma' op expected 1 source values for this operation but got 2}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<32xf32>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<2xf32>, vector<32xf32>
   func.return %d : vector<32xf32>
 }
 
@@ -18,7 +40,7 @@ func.func @bad_source_arguments_i8(%a: vector<8xi8>, %b: vector<8xi8>,
   // expected-error@+1 {{'amdgpu.mfma' op expected 4 source values for this operation but got 8}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 4 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<4xi32>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
   func.return %d : vector<4xi32>
 }
 
@@ -28,7 +50,7 @@ func.func @bad_dest_type(%a: f32, %b: f32, %c: vector<16xf32>) -> vector<16xf32>
   // expected-error@+1 {{'amdgpu.mfma' op expected 32 result values for this operation but got 16}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, vector<16xf32>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<16xf32>
   return %d : vector<16xf32>
 }
 
@@ -38,7 +60,7 @@ func.func @f64_permuting_b(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64>
   // expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of B}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, vector<4xf64>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, f64, vector<4xf64>
   return %d : vector<4xf64>
 }
 
@@ -48,7 +70,7 @@ func.func @f64_permuting_a(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64>
   // expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of A}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
-    abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, vector<4xf64>
+    abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, f64, vector<4xf64>
   return %d : vector<4xf64>
 }
 
@@ -58,7 +80,7 @@ func.func @abid_without_bradcast(%a: f32, %b: f32, %c: vector<32xf32>) -> vector
   // expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, vector<32xf32>
+    abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }
 
@@ -68,7 +90,7 @@ func.func @abid_too_large(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32
   // expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, vector<32xf32>
+    abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }
 
@@ -78,6 +100,6 @@ func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
   // expected-error@+1 {{'amdgpu.mfma' op negation flags only available for double-precision operations}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, vector<32xf32>
+    abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }

diff  --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index d303472c0e089..0e13e1ecfd66c 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -70,6 +70,6 @@ func.func @lds_barrier() {
 // CHECK-LABEL: func @mfma
 func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
   // CHECK: amdgpu.mfma
-  %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, vector<32xf32>
+  %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>  // type list now names source A, source B, then the accumulator
   func.return %0 : vector<32xf32>
 }

diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index a6ca45ce31918..49a01f66c9bc4 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -69,7 +69,7 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
                    %arg4 : vector<16 x f32>, %arg5 : vector<4xf32>,
                    %arg6 : vector<4xf16>, %arg7 : vector<32 x i32>,
                    %arg8 : vector<16 x i32>, %arg9 : vector<4xi32>,
-                   %arg10 : vector<2xi16>) -> vector<32 x f32> {
+                   %arg10 : vector<2xi16>, %arg11 : i64) -> vector<32 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.xdlops
@@ -173,6 +173,45 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
                             (vector<2xi16>, vector<2xi16>, vector<4xf32>,
                             i32, i32, i32) -> vector<4xf32>
 
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf8.bf8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r20 = rocdl.mfma.f32.16x16x32.bf8.bf8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf8.fp8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r21 = rocdl.mfma.f32.16x16x32.bf8.fp8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.fp8.bf8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r22 = rocdl.mfma.f32.16x16x32.fp8.bf8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.fp8.fp8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r23 = rocdl.mfma.f32.16x16x32.fp8.fp8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf8.bf8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r24 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf8.fp8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r25 = rocdl.mfma.f32.32x32x16.bf8.fp8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r26 = rocdl.mfma.f32.32x32x16.fp8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.fp8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r27 = rocdl.mfma.f32.32x32x16.fp8.fp8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
   llvm.return %r0 : vector<32 x f32>
 }
 


        


More information about the Mlir-commits mailing list