[Mlir-commits] [mlir] c118864 - [MLIR][ROCDL]Add MFMA_*_F8F6F4 instructions to the ROCDL dialect (#123830)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 23 11:27:59 PST 2025


Author: Yi Qian
Date: 2025-01-23T19:27:56Z
New Revision: c118864223c6309378cd704f3406533474c2759f

URL: https://github.com/llvm/llvm-project/commit/c118864223c6309378cd704f3406533474c2759f
DIFF: https://github.com/llvm/llvm-project/commit/c118864223c6309378cd704f3406533474c2759f.diff

LOG: [MLIR][ROCDL]Add MFMA_*_F8F6F4 instructions to the ROCDL dialect (#123830)

This PR adds mfma.scale.f32.32x32x64.f8f6f4 and
mfma.scale.f32.16x16x128.f8f6f4 to the ROCDL dialect. They are converted
to the corresponding intrinsics in the mlir-to-llvmir pass.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/test/Dialect/LLVMIR/rocdl.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index e9e62a74237c4f..95fbe7ed66a434 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -343,6 +343,18 @@ class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
     "$args attr-dict `:` functional-type($args, $res)";
 }
 
+//===---------------------------------------------------------------------===//
+// MFMA intrinsics with overloaded operands
+class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
+                        list<Trait> traits = []> :
+  LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+                  "amdgcn_" # !subst(".","_", mnemonic),
+                  [], overloadedOperands, traits, 1>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  let assemblyFormat =
+    "$args attr-dict `:` functional-type($args, $res)";
+}
+
 // Available on all CDNA.
 def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">;
 def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">;
@@ -394,7 +406,8 @@ def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16">;
 def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16">;
 def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
 def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
-
+def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
+def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
 class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,

diff  --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index c80ebebaafe3ad..712f8c2a1caf66 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -66,7 +66,8 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
                    %arg8 : vector<16xi32>, %arg9 : vector<4xi32>,
                    %arg10 : vector<2xi16>, %arg11 : vector<4xi16>,
                    %arg12 : vector<4xf64>, %arg13 : f64,
-                   %arg14 : i64, %arg15 : vector<2xf32>) {
+                   %arg14 : i64, %arg15 : vector<2xf32>,
+                   %arg16: vector<8xbf16>, %arg17 : vector<8xf16>) {
   // CHECK-LABEL: rocdl.xdlops
   // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
   %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
@@ -224,6 +225,312 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
                             (vector<2xf32>, vector<2xf32>, vector<16xf32>,
                             i32, i32, i32) -> vector<16xf32>
 
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r31 = rocdl.mfma.f32.16x16x32.bf16 %arg16, %arg16, %arg5, %arg3, %arg3, %arg3 :
+                              (vector<8xbf16>, vector<8xbf16>, vector<4xf32>,
+                               i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.mfma.i32.16x16x64.i8 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  %r32 = rocdl.mfma.i32.16x16x64.i8 %arg9, %arg9, %arg9, %arg3, %arg3, %arg3 :
+                              (vector<4xi32>, vector<4xi32>, vector<4xi32>,
+                               i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r33 = rocdl.mfma.f32.16x16x32.f16 %arg17, %arg17, %arg5, %arg3, %arg3, %arg3 :
+                               (vector<8xf16>, vector<8xf16>, vector<4xf32>,
+                                i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r34 = rocdl.mfma.f32.32x32x16.bf16 %arg16, %arg16, %arg4, %arg3, %arg3, %arg3 :
+                               (vector<8xbf16>, vector<8xbf16>, vector<16xf32>,
+                                i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.mfma.i32.32x32x32.i8 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  %r35 = rocdl.mfma.i32.32x32x32.i8 %arg9, %arg9, %arg8, %arg3, %arg3, %arg3 :
+                               (vector<4xi32>, vector<4xi32>, vector<16xi32>,
+                                i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: rocdl.mfma.f32.32x32x16.f16 {{.*}} : (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r36 = rocdl.mfma.f32.32x32x16.f16 %arg17, %arg17, %arg4, %arg3, %arg3, %arg3 :
+                               (vector<8xf16>, vector<8xf16>, vector<16xf32>,
+                                i32, i32, i32) -> vector<16xf32>
+
+  llvm.return
+}
+
+llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
+                   %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
+                   %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {
+  %cst0 = llvm.mlir.constant(0 : i32) : i32
+  %cst1 = llvm.mlir.constant(1 : i32) : i32
+  %cst2 = llvm.mlir.constant(2 : i32) : i32
+  %cst3 = llvm.mlir.constant(3 : i32) : i32
+  %cst4 = llvm.mlir.constant(4 : i32) : i32
+
+  // CHECK-LABEL: rocdl.mfma.scale.f32.32x32x64.f8f6f4
+  // fp8 * fp8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r00 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * bf8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r01 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * fp6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r02 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * bf6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r03 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * fp4
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r04 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r10 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * bf8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r11 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r12 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * bf6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r13 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp4
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r14 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r20 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * bf8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r21 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r22 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * bf6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r23 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp4
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r24 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r30 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * bf8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r31 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r32 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * bf6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r33 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp4
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r34 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r40 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * bf8
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r41 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r42 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * bf6
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r43 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp4
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  %r44 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  llvm.return
+}
+
+llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
+                   %arg1 : vector<4 x f32>, %arg2 : vector<8xi32>,
+                   %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {
+  %cst0 = llvm.mlir.constant(0 : i32) : i32
+  %cst1 = llvm.mlir.constant(1 : i32) : i32
+  %cst2 = llvm.mlir.constant(2 : i32) : i32
+  %cst3 = llvm.mlir.constant(3 : i32) : i32
+  %cst4 = llvm.mlir.constant(4 : i32) : i32
+
+  // CHECK-LABEL: rocdl.mfma.scale.f32.16x16x128.f8f6f4
+  // fp8 * fp8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * bf8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * fp6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * bf6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * fp4
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * bf8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * bf6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp4
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * bf8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * bf6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r23 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp4
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r24 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * fp8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r30 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * bf8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r31 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * fp6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r32 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * bf6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r33 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * fp4
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r34 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * fp8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r40 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * bf8
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r41 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * fp6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r42 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * bf6
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r43 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * fp4
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  %r44 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
   llvm.return
 }
 

diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 8879ba02b24057..b74edb62106837 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -398,6 +398,282 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   llvm.return %r0 : vector<32 x f32>
 }
 
+llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
+                   %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
+                   %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> {
+  %cst0 = llvm.mlir.constant(0 : i32) : i32
+  %cst1 = llvm.mlir.constant(1 : i32) : i32
+  %cst2 = llvm.mlir.constant(2 : i32) : i32
+  %cst3 = llvm.mlir.constant(3 : i32) : i32
+  %cst4 = llvm.mlir.constant(4 : i32) : i32
+
+  // CHECK-LABEL: rocdl.mfma.scale.f32.32x32x64.f8f6f4
+  // fp8 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r00 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r01 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r02 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r03 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r04 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r10 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r11 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r12 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r13 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r14 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r20 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r21 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r22 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r23 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r24 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r30 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r31 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r32 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r33 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r34 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r40 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r41 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r42 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r43 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r44 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  llvm.return %r00 : vector<16 x f32>
+}
+
+llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
+                   %arg1 : vector<4 x f32>, %arg2 : vector<8xi32>,
+                   %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<4 x f32> {
+  %cst0 = llvm.mlir.constant(0 : i32) : i32
+  %cst1 = llvm.mlir.constant(1 : i32) : i32
+  %cst2 = llvm.mlir.constant(2 : i32) : i32
+  %cst3 = llvm.mlir.constant(3 : i32) : i32
+  %cst4 = llvm.mlir.constant(4 : i32) : i32
+
+  // CHECK-LABEL: rocdl.mfma.scale.f32.16x16x128.f8f6f4
+  // fp8 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r23 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r24 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r30 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r31 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r32 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r33 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf6 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r34 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r40 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r41 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r42 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r43 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp4 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r44 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  llvm.return %r00 : vector<4 x f32>
+}
+
 llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
                       %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
   %zero = llvm.mlir.constant(false) : i1


        


More information about the Mlir-commits mailing list