[Mlir-commits] [mlir] [ROCDL] Refactored MFMA ops in ODS; added constraints (PR #175775)
Ravil Dorozhinskii
llvmlistbot at llvm.org
Wed Jan 14 06:08:48 PST 2026
https://github.com/ravil-mobile updated https://github.com/llvm/llvm-project/pull/175775
>From c171c9b388722b51e92ad45dca96115d5507cb65 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Tue, 13 Jan 2026 15:08:51 +0000
Subject: [PATCH 1/6] [ROCDL] Refactored MFMA ops in ODS; added constraints
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 105 +++++++++++--------
mlir/test/Dialect/LLVMIR/rocdl.mlir | 2 +-
mlir/test/Target/LLVMIR/rocdl.mlir | 2 +-
3 files changed, 62 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 265c2e99f52d6..ef070d26e2451 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -617,6 +617,21 @@ class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
"$args attr-dict `:` functional-type($args, $res)";
}
+class ROCDL_Mfma_IntrOpV0<string mnemonic, Type AB, Type CD> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [], []>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$a,
+ LLVM_ScalarOrVectorOf<AB>:$b,
+ LLVM_ScalarOrVectorOf<CD>:$c,
+ I32:$cbsz,
+ I32:$abid,
+ I32:$blgp)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $a `,` $b `,` $c `,` $cbsz `,` $abid `,` $blgp attr-dict `:` functional-type(operands, $res)
+ }];
+}
+
//===---------------------------------------------------------------------===//
// MFMA intrinsics with overloaded operands
class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
@@ -628,56 +643,56 @@ class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
}
// Available on all CDNA.
-def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">;
-def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">;
-def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">;
-def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32">;
-def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32">;
-def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16">;
-def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16">;
-def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16">;
-def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">;
-def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">;
-def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8">;
-def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8">;
-def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8">;
-def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">;
-def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
-def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">;
-def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">;
-def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">;
-def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">;
-def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">;
+def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x1f32", /*Type AB=*/F32, /*Type CD=*/F32>;
+def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x1f32", F32, F32>;
+def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x1f32", F32, F32>;
+def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x2f32", F32, F32>;
+def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x4f32", F32, F32>;
+def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4f16", F16, F32>;
+def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x4f16", F16, F32>;
+def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x4f16", F16, F32>;
+def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x8f16", F16, F32>;
+def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x16f16", F16, F32>;
+def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x4i8", I32, I32>;
+def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x4i8", I32, I32>;
+def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.4x4x4i8", I32, I32>;
+def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x8i8", I32, I32>;
+def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x16i8", I32, I32>;
+def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x2bf16", I16, F32>;
+def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x2bf16", I16, F32>;
+def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x2bf16", I16, F32>;
+def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4bf16", I16, F32>;
+def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x8bf16", I16, F32>;
// New in gfx90a.
-def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16.1k">;
-def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4bf16.1k">;
-def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4bf16.1k">;
-def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8bf16.1k">;
-def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16bf16.1k">;
+def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x4bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x4bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x8bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x16bf16.1k", I16, F32>;
// Note: in gfx94x, unlike in gfx90a, the f64 xdlops use the "blgp" argument as
// a NEG bitfield. See IntrinsicsAMDGPU.td for more info.
-def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64">;
-def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.4x4x4f64">;
+def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOpV0<"mfma.f64.16x16x4f64", F64, F64>;
+def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOpV0<"mfma.f64.4x4x4f64", F64, F64>;
// New in gfx94x.
-def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8">;
-def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">;
-def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">;
-def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">;
-def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8">;
-def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8">;
-def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8">;
-def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8">;
-def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8">;
-def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8">;
-def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
-def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
+def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x32.i8", I64, I32>;
+def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x16.i8", I64, I32>;
+def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x8.xf32", F32, F32>;
+def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4.xf32", F32, F32>;
+def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.bf8.bf8", I64, F32>;
+def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.bf8.fp8", I64, F32>;
+def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.fp8.bf8", I64, F32>;
+def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.fp8.fp8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.bf8.bf8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.bf8.fp8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.fp8.bf8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.fp8.fp8", I64, F32>;
// New in gfx950.
-def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16">;
-def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8">;
-def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16">;
-def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16">;
-def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
-def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
+def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.bf16", BF16, F32>;
+def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x64.i8", I32, I32>;
+def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.f16", F16, F32>;
+def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.bf16", BF16, F32>;
+def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x32.i8", I32, I32>;
+def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.f16", F16, F32>;
def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index cf2b144219f36..daa46b72057ca 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -311,7 +311,7 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
// CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xi32>
%r33 = rocdl.mfma.f32.16x16x32.f16 %arg17, %arg17, %arg5, %arg3, %arg3, %arg3 :
(vector<8xf16>, vector<8xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xi32>
+ i32, i32, i32) -> vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
%r34 = rocdl.mfma.f32.32x32x16.bf16 %arg16, %arg16, %arg4, %arg3, %arg3, %arg3 :
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index dc6a00e19afc3..af77f09c091d5 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -566,7 +566,7 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
%r30 = rocdl.mfma.f32.16x16x32.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
(vector<8xf16>, vector<8xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xi32>
+ i32, i32, i32) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %1{{.*}}, <8 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
%r31 = rocdl.mfma.f32.32x32x16.bf16 %arg12, %arg12, %arg4, %csti32, %csti32, %csti32 :
>From 4bcc323599146efbb705876b485701e4c3001fca Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Tue, 13 Jan 2026 15:50:03 +0000
Subject: [PATCH 2/6] [ROCDL] Promoted cbsz/abid/blgp operands of MFMA ops to attributes
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 8 +-
mlir/test/Dialect/LLVMIR/rocdl.mlir | 333 +++++++++----------
mlir/test/Target/LLVMIR/rocdl.mlir | 171 ++++------
3 files changed, 220 insertions(+), 292 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index ef070d26e2451..2add87da53e81 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -618,14 +618,14 @@ class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
}
class ROCDL_Mfma_IntrOpV0<string mnemonic, Type AB, Type CD> :
- ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [], []>,
+ ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [3, 4, 5], ["cbsz", "abid", "blgp"]>,
Arguments<(ins
LLVM_ScalarOrVectorOf<AB>:$a,
LLVM_ScalarOrVectorOf<AB>:$b,
LLVM_ScalarOrVectorOf<CD>:$c,
- I32:$cbsz,
- I32:$abid,
- I32:$blgp)> {
+ I32Attr:$cbsz,
+ I32Attr:$abid,
+ I32Attr:$blgp)> {
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
let assemblyFormat = [{
$a `,` $b `,` $c `,` $cbsz `,` $abid `,` $blgp attr-dict `:` functional-type(operands, $res)
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index daa46b72057ca..47eb2944f18f7 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -142,191 +142,154 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
%arg14 : i64, %arg15 : vector<2xf32>,
%arg16: vector<8xbf16>, %arg17 : vector<8xf16>) {
// CHECK-LABEL: rocdl.xdlops
- // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
- (f32, f32, vector<32xf32>,
- i32, i32, i32) -> vector<32xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x1f32 {{.*}} : (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
- (f32, f32, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.4x4x1f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r2 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
- (f32, f32, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f32.32x32x2f32 {{.*}} : (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r3= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
- (f32, f32, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r4 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
- (f32, f32, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f32.32x32x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 :
- (vector<4xf16>, vector<4xf16>, vector<32xf32>,
- i32, i32, i32) -> vector<32xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
- (vector<4xf16>, vector<4xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.4x4x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
- (vector<4xf16>, vector<4xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f32.32x32x8f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
- (vector<4xf16>, vector<4xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x16f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
- (vector<4xf16>, vector<4xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.i32.32x32x4i8 {{.*}} : (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
- %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %arg3, %arg3, %arg3 :
- (i32, i32, vector<32xi32>,
- i32, i32, i32) -> vector<32xi32>
-
- // CHECK: rocdl.mfma.i32.16x16x4i8 {{.*}} : (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
- (i32, i32, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: rocdl.mfma.i32.4x4x4i8 {{.*}} : (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
- (i32, i32, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: rocdl.mfma.i32.32x32x8i8 {{.*}} : (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
- (i32, i32, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: rocdl.mfma.i32.16x16x16i8 {{.*}} : (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
- (i32, i32, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: rocdl.mfma.f32.32x32x2bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %arg3, %arg3, %arg3 :
- (vector<2xi16>, vector<2xi16>, vector<32xf32>,
- i32, i32, i32) -> vector<32xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x2bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
- (vector<2xi16>, vector<2xi16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.4x4x2bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
- (vector<2xi16>, vector<2xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f32.32x32x4bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
- (vector<2xi16>, vector<2xi16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x8bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
- (vector<2xi16>, vector<2xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
-
- // CHECK: rocdl.mfma.f32.32x32x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- %r20 = rocdl.mfma.f32.32x32x4bf16.1k %arg11, %arg11, %arg2, %arg3, %arg3, %arg3 :
- (vector<4xi16>, vector<4xi16>, vector<32xf32>,
- i32, i32, i32) -> vector<32xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r21 = rocdl.mfma.f32.16x16x4bf16.1k %arg11, %arg11, %arg4, %arg3, %arg3, %arg3 :
- (vector<4xi16>, vector<4xi16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.4x4x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r22 = rocdl.mfma.f32.4x4x4bf16.1k %arg11, %arg11, %arg5, %arg3, %arg3, %arg3 :
- (vector<4xi16>, vector<4xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f32.32x32x8bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r23 = rocdl.mfma.f32.32x32x8bf16.1k %arg11, %arg11, %arg4, %arg3, %arg3, %arg3 :
- (vector<4xi16>, vector<4xi16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x16bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r24 = rocdl.mfma.f32.16x16x16bf16.1k %arg11, %arg11, %arg5, %arg3, %arg3, %arg3 :
- (vector<4xi16>, vector<4xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f64.16x16x4f64 {{.*}} : (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
- %r25 = rocdl.mfma.f64.16x16x4f64 %arg13, %arg13, %arg12, %arg3, %arg3, %arg3 :
- (f64, f64, vector<4xf64>,
- i32, i32, i32) -> vector<4xf64>
-
- // CHECK: rocdl.mfma.f64.4x4x4f64 {{.*}} : (f64, f64, f64, i32, i32, i32) -> f64
- %r26 = rocdl.mfma.f64.4x4x4f64 %arg13, %arg13, %arg13, %arg3, %arg3, %arg3 :
- (f64, f64, f64,
- i32, i32, i32) -> f64
-
- // CHECK: rocdl.mfma.i32.16x16x32.i8 {{.*}} : (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- %r27 = rocdl.mfma.i32.16x16x32.i8 %arg14, %arg14, %arg9, %arg3, %arg3, %arg3 :
- (i64, i64, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: rocdl.mfma.i32.32x32x16.i8 {{.*}} : (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- %r28 = rocdl.mfma.i32.32x32x16.i8 %arg14, %arg14, %arg8, %arg3, %arg3, %arg3 :
- (i64, i64, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: rocdl.mfma.f32.16x16x8.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r29 = rocdl.mfma.f32.16x16x8.xf32 %arg15, %arg15, %arg5, %arg3, %arg3, %arg3 :
- (vector<2xf32>, vector<2xf32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f32.32x32x4.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r30 = rocdl.mfma.f32.32x32x4.xf32 %arg15, %arg15, %arg4, %arg3, %arg3, %arg3 :
- (vector<2xf32>, vector<2xf32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.f32.16x16x32.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r31 = rocdl.mfma.f32.16x16x32.bf16 %arg16, %arg16, %arg5, %arg3, %arg3, %arg3 :
- (vector<8xbf16>, vector<8xbf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.i32.16x16x64.i8 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- %r32 = rocdl.mfma.i32.16x16x64.i8 %arg9, %arg9, %arg9, %arg3, %arg3, %arg3 :
- (vector<4xi32>, vector<4xi32>, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xi32>
- %r33 = rocdl.mfma.f32.16x16x32.f16 %arg17, %arg17, %arg5, %arg3, %arg3, %arg3 :
- (vector<8xf16>, vector<8xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r34 = rocdl.mfma.f32.32x32x16.bf16 %arg16, %arg16, %arg4, %arg3, %arg3, %arg3 :
- (vector<8xbf16>, vector<8xbf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.mfma.i32.32x32x32.i8 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- %r35 = rocdl.mfma.i32.32x32x32.i8 %arg9, %arg9, %arg8, %arg3, %arg3, %arg3 :
- (vector<4xi32>, vector<4xi32>, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: rocdl.mfma.f32.32x32x16.f16 {{.*}} : (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r36 = rocdl.mfma.f32.32x32x16.f16 %arg17, %arg17, %arg4, %arg3, %arg3, %arg3 :
- (vector<8xf16>, vector<8xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (f32, f32, vector<32xf32>) -> vector<32xf32>
+ %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, 0, 0, 0 :
+ (f32, f32, vector<32xf32>) -> vector<32xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x1f32 {{.*}} : (f32, f32, vector<16xf32>) -> vector<16xf32>
+ %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, 0, 0, 0 :
+ (f32, f32, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.4x4x1f32 {{.*}} : (f32, f32, vector<4xf32>) -> vector<4xf32>
+ %r2 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, 0, 0, 0 :
+ (f32, f32, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f32.32x32x2f32 {{.*}} : (f32, f32, vector<16xf32>) -> vector<16xf32>
+ %r3= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, 0, 0, 0 :
+ (f32, f32, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (f32, f32, vector<4xf32>) -> vector<4xf32>
+ %r4 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, 0, 0, 0 :
+ (f32, f32, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f32.32x32x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<32xf32>) -> vector<32xf32>
+ %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<32xf32>) -> vector<32xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<16xf32>) -> vector<16xf32>
+ %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.4x4x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+ %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f32.32x32x8f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<16xf32>) -> vector<16xf32>
+ %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x16f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+ %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.i32.32x32x4i8 {{.*}} : (i32, i32, vector<32xi32>) -> vector<32xi32>
+ %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, 0, 0, 0 :
+ (i32, i32, vector<32xi32>) -> vector<32xi32>
+
+ // CHECK: rocdl.mfma.i32.16x16x4i8 {{.*}} : (i32, i32, vector<16xi32>) -> vector<16xi32>
+ %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, 0, 0, 0 :
+ (i32, i32, vector<16xi32>) -> vector<16xi32>
+
+ // CHECK: rocdl.mfma.i32.4x4x4i8 {{.*}} : (i32, i32, vector<4xi32>) -> vector<4xi32>
+ %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, 0, 0, 0 :
+ (i32, i32, vector<4xi32>) -> vector<4xi32>
+
+ // CHECK: rocdl.mfma.i32.32x32x8i8 {{.*}} : (i32, i32, vector<16xi32>) -> vector<16xi32>
+ %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, 0, 0, 0 :
+ (i32, i32, vector<16xi32>) -> vector<16xi32>
+
+ // CHECK: rocdl.mfma.i32.16x16x16i8 {{.*}} : (i32, i32, vector<4xi32>) -> vector<4xi32>
+ %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, 0, 0, 0 :
+ (i32, i32, vector<4xi32>) -> vector<4xi32>
+
+ // CHECK: rocdl.mfma.f32.32x32x2bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<32xf32>) -> vector<32xf32>
+ %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<32xf32>) -> vector<32xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x2bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<16xf32>) -> vector<16xf32>
+ %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.4x4x2bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
+ %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f32.32x32x4bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<16xf32>) -> vector<16xf32>
+ %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x8bf16 {{.*}} : (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
+ %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
+
+
+ // CHECK: rocdl.mfma.f32.32x32x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<32xf32>) -> vector<32xf32>
+ %r20 = rocdl.mfma.f32.32x32x4bf16.1k %arg11, %arg11, %arg2, 0, 0, 0 :
+ (vector<4xi16>, vector<4xi16>, vector<32xf32>) -> vector<32xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>) -> vector<16xf32>
+ %r21 = rocdl.mfma.f32.16x16x4bf16.1k %arg11, %arg11, %arg4, 0, 0, 0 :
+ (vector<4xi16>, vector<4xi16>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.4x4x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+ %r22 = rocdl.mfma.f32.4x4x4bf16.1k %arg11, %arg11, %arg5, 0, 0, 0 :
+ (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f32.32x32x8bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>) -> vector<16xf32>
+ %r23 = rocdl.mfma.f32.32x32x8bf16.1k %arg11, %arg11, %arg4, 0, 0, 0 :
+ (vector<4xi16>, vector<4xi16>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x16bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+ %r24 = rocdl.mfma.f32.16x16x16bf16.1k %arg11, %arg11, %arg5, 0, 0, 0 :
+ (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f64.16x16x4f64 {{.*}} : (f64, f64, vector<4xf64>) -> vector<4xf64>
+ %r25 = rocdl.mfma.f64.16x16x4f64 %arg13, %arg13, %arg12, 0, 0, 0 :
+ (f64, f64, vector<4xf64>) -> vector<4xf64>
+
+ // CHECK: rocdl.mfma.f64.4x4x4f64 {{.*}} : (f64, f64, f64) -> f64
+ %r26 = rocdl.mfma.f64.4x4x4f64 %arg13, %arg13, %arg13, 0, 0, 0 :
+ (f64, f64, f64) -> f64
+
+ // CHECK: rocdl.mfma.i32.16x16x32.i8 {{.*}} : (i64, i64, vector<4xi32>) -> vector<4xi32>
+ %r27 = rocdl.mfma.i32.16x16x32.i8 %arg14, %arg14, %arg9, 0, 0, 0 :
+ (i64, i64, vector<4xi32>) -> vector<4xi32>
+
+ // CHECK: rocdl.mfma.i32.32x32x16.i8 {{.*}} : (i64, i64, vector<16xi32>) -> vector<16xi32>
+ %r28 = rocdl.mfma.i32.32x32x16.i8 %arg14, %arg14, %arg8, 0, 0, 0 :
+ (i64, i64, vector<16xi32>) -> vector<16xi32>
+
+ // CHECK: rocdl.mfma.f32.16x16x8.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<4xf32>) -> vector<4xf32>
+ %r29 = rocdl.mfma.f32.16x16x8.xf32 %arg15, %arg15, %arg5, 0, 0, 0 :
+ (vector<2xf32>, vector<2xf32>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f32.32x32x4.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<16xf32>) -> vector<16xf32>
+ %r30 = rocdl.mfma.f32.32x32x4.xf32 %arg15, %arg15, %arg4, 0, 0, 0 :
+ (vector<2xf32>, vector<2xf32>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.f32.16x16x32.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>, vector<4xf32>) -> vector<4xf32>
+ %r31 = rocdl.mfma.f32.16x16x32.bf16 %arg16, %arg16, %arg5, 0, 0, 0 :
+ (vector<8xbf16>, vector<8xbf16>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.i32.16x16x64.i8 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
+ %r32 = rocdl.mfma.i32.16x16x64.i8 %arg9, %arg9, %arg9, 0, 0, 0 :
+ (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
+
+ // CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16>, vector<4xf32>) -> vector<4xf32>
+ %r33 = rocdl.mfma.f32.16x16x32.f16 %arg17, %arg17, %arg5, 0, 0, 0 :
+ (vector<8xf16>, vector<8xf16>, vector<4xf32>) -> vector<4xf32>
+
+ // CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>, vector<16xf32>) -> vector<16xf32>
+ %r34 = rocdl.mfma.f32.32x32x16.bf16 %arg16, %arg16, %arg4, 0, 0, 0 :
+ (vector<8xbf16>, vector<8xbf16>, vector<16xf32>) -> vector<16xf32>
+
+ // CHECK: rocdl.mfma.i32.32x32x32.i8 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xi32>) -> vector<16xi32>
+ %r35 = rocdl.mfma.i32.32x32x32.i8 %arg9, %arg9, %arg8, 0, 0, 0 :
+ (vector<4xi32>, vector<4xi32>, vector<16xi32>) -> vector<16xi32>
+
+ // CHECK: rocdl.mfma.f32.32x32x16.f16 {{.*}} : (vector<8xf16>, vector<8xf16>, vector<16xf32>) -> vector<16xf32>
+ %r36 = rocdl.mfma.f32.32x32x16.f16 %arg17, %arg17, %arg4, 0, 0, 0 :
+ (vector<8xf16>, vector<8xf16>, vector<16xf32>) -> vector<16xf32>
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index af77f09c091d5..1e7ff3aaba622 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -410,178 +410,143 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
%arg10 : vector<2xi16>, %arg11 : i64,
%arg12 : vector<8xbf16>, %arg13 : vector<4xi32>,
%arg14 : vector<8xf16>) -> vector<32 x f32> {
- %csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.xdlops
// CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %{{.*}}, float %{{.*}}, <32 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %csti32, %csti32, %csti32 :
- (f32, f32, vector<32 x f32>,
- i32, i32, i32) -> vector<32 x f32>
+ %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, 0, 0, 0 :
+ (f32, f32, vector<32 x f32>) -> vector<32 x f32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x1f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %csti32, %csti32, %csti32 :
- (f32, f32, vector<16 x f32>,
- i32, i32, i32) -> vector<16 x f32>
+ %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, 0, 0, 0 :
+ (f32, f32, vector<16 x f32>) -> vector<16 x f32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x4f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %csti32, %csti32, %csti32 :
- (f32, f32, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, 0, 0, 0 :
+ (f32, f32, vector<4xf32>) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x1f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %csti32, %csti32, %csti32 :
- (f32, f32, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, 0, 0, 0 :
+ (f32, f32, vector<4xf32>) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x2f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r4= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %csti32, %csti32, %csti32 :
- (f32, f32, vector<16 x f32>,
- i32, i32, i32) -> vector<16 x f32>
+ %r4= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, 0, 0, 0 :
+ (f32, f32, vector<16 x f32>) -> vector<16 x f32>
// CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <32 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<4xf16>, vector<32 x f32>,
- i32, i32, i32) -> vector<32 x f32>
+ %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<32 x f32>) -> vector<32 x f32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<4xf16>, vector<16 x f32>,
- i32, i32, i32) -> vector<16 x f32>
+ %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<16 x f32>) -> vector<16 x f32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<4xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<4xf16>, vector<16 x f32>,
- i32, i32, i32) -> vector<16 x f32>
+ %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<16 x f32>) -> vector<16 x f32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<4xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, 0, 0, 0 :
+ (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
// CHECK: call <32 x i32> @llvm.amdgcn.mfma.i32.32x32x4i8(i32 %{{.*}}, i32 %{{.*}}, <32 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %csti32, %csti32, %csti32 :
- (i32, i32, vector<32 x i32>,
- i32, i32, i32) -> vector<32 x i32>
+ %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, 0, 0, 0 :
+ (i32, i32, vector<32 x i32>) -> vector<32 x i32>
// CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.16x16x4i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %csti32, %csti32, %csti32 :
- (i32, i32, vector<16 x i32>,
- i32, i32, i32) -> vector<16 x i32>
+ %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, 0, 0, 0 :
+ (i32, i32, vector<16 x i32>) -> vector<16 x i32>
// CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.4x4x4i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %csti32, %csti32, %csti32 :
- (i32, i32, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
+ %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, 0, 0, 0 :
+ (i32, i32, vector<4xi32>) -> vector<4xi32>
// CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x8i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %csti32, %csti32, %csti32 :
- (i32, i32, vector<16 x i32>,
- i32, i32, i32) -> vector<16 x i32>
+ %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, 0, 0, 0 :
+ (i32, i32, vector<16 x i32>) -> vector<16 x i32>
// CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %csti32, %csti32, %csti32 :
- (i32, i32, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
+ %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, 0, 0, 0 :
+ (i32, i32, vector<4xi32>) -> vector<4xi32>
// CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <32 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %csti32, %csti32, %csti32 :
- (vector<2xi16>, vector<2xi16>, vector<32 x f32>,
- i32, i32, i32) -> vector<32 x f32>
+ %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<32 x f32>) -> vector<32 x f32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi16>, vector<2xi16>, vector<16 x f32>,
- i32, i32, i32) -> vector<16 x f32>
+ %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<16 x f32>) -> vector<16 x f32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %csti32, %csti32, %csti32 :
- (vector<2xi16>, vector<2xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x4bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi16>, vector<2xi16>, vector<16 x f32>,
- i32, i32, i32) -> vector<16 x f32>
+ %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<16 x f32>) -> vector<16 x f32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x8bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %csti32, %csti32, %csti32 :
- (vector<2xi16>, vector<2xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, 0, 0, 0 :
+ (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf8.bf8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r20 = rocdl.mfma.f32.16x16x32.bf8.bf8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
- (i64, i64, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r20 = rocdl.mfma.f32.16x16x32.bf8.bf8 %arg11, %arg11, %arg5, 0, 0, 0 :
+ (i64, i64, vector<4xf32>) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf8.fp8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r21 = rocdl.mfma.f32.16x16x32.bf8.fp8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
- (i64, i64, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r21 = rocdl.mfma.f32.16x16x32.bf8.fp8 %arg11, %arg11, %arg5, 0, 0, 0 :
+ (i64, i64, vector<4xf32>) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.fp8.bf8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r22 = rocdl.mfma.f32.16x16x32.fp8.bf8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
- (i64, i64, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r22 = rocdl.mfma.f32.16x16x32.fp8.bf8 %arg11, %arg11, %arg5, 0, 0, 0 :
+ (i64, i64, vector<4xf32>) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.fp8.fp8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r23 = rocdl.mfma.f32.16x16x32.fp8.fp8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
- (i64, i64, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r23 = rocdl.mfma.f32.16x16x32.fp8.fp8 %arg11, %arg11, %arg5, 0, 0, 0 :
+ (i64, i64, vector<4xf32>) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf8.bf8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r24 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
- (i64, i64, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ %r24 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, 0, 0, 0 :
+ (i64, i64, vector<16xf32>) -> vector<16xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf8.fp8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r25 = rocdl.mfma.f32.32x32x16.bf8.fp8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
- (i64, i64, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ %r25 = rocdl.mfma.f32.32x32x16.bf8.fp8 %arg11, %arg11, %arg4, 0, 0, 0 :
+ (i64, i64, vector<16xf32>) -> vector<16xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r26 = rocdl.mfma.f32.32x32x16.fp8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
- (i64, i64, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ %r26 = rocdl.mfma.f32.32x32x16.fp8.bf8 %arg11, %arg11, %arg4, 0, 0, 0 :
+ (i64, i64, vector<16xf32>) -> vector<16xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf8.bf8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r27 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
- (i64, i64, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ %r27 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, 0, 0, 0 :
+ (i64, i64, vector<16xf32>) -> vector<16xf32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r28 = rocdl.mfma.f32.16x16x32.bf16 %arg12, %arg12, %arg5, %csti32, %csti32, %csti32 :
- (vector<8xbf16>, vector<8xbf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r28 = rocdl.mfma.f32.16x16x32.bf16 %arg12, %arg12, %arg5, 0, 0, 0 :
+ (vector<8xbf16>, vector<8xbf16>, vector<4xf32>) -> vector<4xf32>
// CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x64.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r29 = rocdl.mfma.i32.16x16x64.i8 %arg9, %arg9, %arg9, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<4xi32>, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
+ %r29 = rocdl.mfma.i32.16x16x64.i8 %arg9, %arg9, %arg9, 0, 0, 0 :
+ (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r30 = rocdl.mfma.f32.16x16x32.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
- (vector<8xf16>, vector<8xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
+ %r30 = rocdl.mfma.f32.16x16x32.f16 %arg14, %arg14, %arg5, 0, 0, 0 :
+ (vector<8xf16>, vector<8xf16>, vector<4xf32>) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %1{{.*}}, <8 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r31 = rocdl.mfma.f32.32x32x16.bf16 %arg12, %arg12, %arg4, %csti32, %csti32, %csti32 :
- (vector<8xbf16>, vector<8xbf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ %r31 = rocdl.mfma.f32.32x32x16.bf16 %arg12, %arg12, %arg4, 0, 0, 0 :
+ (vector<8xbf16>, vector<8xbf16>, vector<16xf32>) -> vector<16xf32>
// CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x32.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r32 = rocdl.mfma.i32.32x32x32.i8 %arg9, %arg9, %arg8, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<4xi32>, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
+ %r32 = rocdl.mfma.i32.32x32x32.i8 %arg9, %arg9, %arg8, 0, 0, 0 :
+ (vector<4xi32>, vector<4xi32>, vector<16xi32>) -> vector<16xi32>
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
- %r33 = rocdl.mfma.f32.32x32x16.f16 %arg14, %arg14, %arg4, %csti32, %csti32, %csti32 :
- (vector<8xf16>, vector<8xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ %r33 = rocdl.mfma.f32.32x32x16.f16 %arg14, %arg14, %arg4, 0, 0, 0 :
+ (vector<8xf16>, vector<8xf16>, vector<16xf32>) -> vector<16xf32>
llvm.return %r0 : vector<32 x f32>
}
>From fda566e4346e93c9abad7e6245dab01e05ae7774 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Tue, 13 Jan 2026 16:37:18 +0000
Subject: [PATCH 3/6] [ROCDL] Added type and structural constraints to SMFMAC Ops
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 170 +++++++------
mlir/test/Dialect/LLVMIR/rocdl.mlir | 252 +++++++++----------
mlir/test/Target/LLVMIR/rocdl.mlir | 252 +++++++++----------
3 files changed, 313 insertions(+), 361 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 2add87da53e81..9ad7672e59cdb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -610,14 +610,7 @@ def ROCDL_IglpOpt : ROCDL_ConcreteNonMemIntrOp<"iglp.opt", [], 0, [0], ["variant
//===---------------------------------------------------------------------===//
// Xdlops intrinsics
-class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
- ROCDL_IntrOp<mnemonic, [], [], traits, 1>,
- Arguments<(ins Variadic<LLVM_Type>:$args)> {
- let assemblyFormat =
- "$args attr-dict `:` functional-type($args, $res)";
-}
-
-class ROCDL_Mfma_IntrOpV0<string mnemonic, Type AB, Type CD> :
+class ROCDL_Mfma_IntrOp<string mnemonic, Type AB, Type CD> :
ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [3, 4, 5], ["cbsz", "abid", "blgp"]>,
Arguments<(ins
LLVM_ScalarOrVectorOf<AB>:$a,
@@ -632,6 +625,21 @@ class ROCDL_Mfma_IntrOpV0<string mnemonic, Type AB, Type CD> :
}];
}
+class ROCDL_Smfmac_IntrOp<string mnemonic, Type AB, Type CD> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [4, 5], ["cbsz", "abid"]>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$a,
+ LLVM_ScalarOrVectorOf<AB>:$b,
+ LLVM_ScalarOrVectorOf<CD>:$c,
+ I32:$index,
+ I32Attr:$cbsz,
+ I32Attr:$abid)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $a `,` $b `,` $c `,` $index `,` $cbsz `,` $abid attr-dict `:` functional-type(operands, $res)
+ }];
+}
+
//===---------------------------------------------------------------------===//
// MFMA intrinsics with overloaded operands
class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
@@ -643,89 +651,89 @@ class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
}
// Available on all CDNA.
-def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x1f32", /*Type AB=*/F32, /*Type CD=*/F32>;
-def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x1f32", F32, F32>;
-def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x1f32", F32, F32>;
-def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x2f32", F32, F32>;
-def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x4f32", F32, F32>;
-def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4f16", F16, F32>;
-def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x4f16", F16, F32>;
-def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x4f16", F16, F32>;
-def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x8f16", F16, F32>;
-def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x16f16", F16, F32>;
-def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x4i8", I32, I32>;
-def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x4i8", I32, I32>;
-def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.4x4x4i8", I32, I32>;
-def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x8i8", I32, I32>;
-def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x16i8", I32, I32>;
-def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x2bf16", I16, F32>;
-def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x2bf16", I16, F32>;
-def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x2bf16", I16, F32>;
-def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4bf16", I16, F32>;
-def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x8bf16", I16, F32>;
+def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32", /*Type AB=*/F32, /*Type CD=*/F32>;
+def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32", F32, F32>;
+def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32", F32, F32>;
+def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32", F32, F32>;
+def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32", F32, F32>;
+def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16", F16, F32>;
+def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16", F16, F32>;
+def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16", F16, F32>;
+def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16", F16, F32>;
+def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16", F16, F32>;
+def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8", I32, I32>;
+def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8", I32, I32>;
+def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8", I32, I32>;
+def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8", I32, I32>;
+def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8", I32, I32>;
+def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16", I16, F32>;
+def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16", I16, F32>;
+def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16", I16, F32>;
+def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16", I16, F32>;
+def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16", I16, F32>;
// New in gfx90a.
-def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x4bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.4x4x4bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x8bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x16bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16bf16.1k", I16, F32>;
// Note: in gfx94x, unlike in gfx90a, the f64 xdlops use the "blgp" argument as
// a NEG bitfield. See IntrinsicsAMDGPU.td for more info.
-def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOpV0<"mfma.f64.16x16x4f64", F64, F64>;
-def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOpV0<"mfma.f64.4x4x4f64", F64, F64>;
+def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64", F64, F64>;
+def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.4x4x4f64", F64, F64>;
// New in gfx94x.
-def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x32.i8", I64, I32>;
-def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x16.i8", I64, I32>;
-def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x8.xf32", F32, F32>;
-def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x4.xf32", F32, F32>;
-def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.bf8.bf8", I64, F32>;
-def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.bf8.fp8", I64, F32>;
-def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.fp8.bf8", I64, F32>;
-def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.fp8.fp8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.bf8.bf8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.bf8.fp8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.fp8.bf8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.fp8.fp8", I64, F32>;
+def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8", I64, I32>;
+def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8", I64, I32>;
+def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32", F32, F32>;
+def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32", F32, F32>;
+def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8", I64, F32>;
+def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8", I64, F32>;
+def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8", I64, F32>;
+def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8", I64, F32>;
+def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8", I64, F32>;
// New in gfx950.
-def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.bf16", BF16, F32>;
-def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.16x16x64.i8", I32, I32>;
-def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.16x16x32.f16", F16, F32>;
-def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.bf16", BF16, F32>;
-def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOpV0<"mfma.i32.32x32x32.i8", I32, I32>;
-def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOpV0<"mfma.f32.32x32x16.f16", F16, F32>;
+def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16", BF16, F32>;
+def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8", I32, I32>;
+def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16", F16, F32>;
+def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16", BF16, F32>;
+def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8", I32, I32>;
+def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16", F16, F32>;
def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;
// 2:4 Sparsity ops (GFX94x)
-def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.f16">;
-def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.f16">;
-def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.bf16">;
-def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.bf16">;
-def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">;
-def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">;
-def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">;
-def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">;
-def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">;
-def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">;
-def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">;
-def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
-def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
-def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x32.f16", F16, F32>;
+def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x16.f16", F16, F32>;
+def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x32.bf16", I16, F32>;
+def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x16.bf16", I16, F32>;
+def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.16x16x64.i8", I32, I32>;
+def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.32x32x32.i8", I32, I32>;
+def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.fp8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.fp8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.fp8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.fp8.fp8", I32, F32>;
// New in gfx950.
-def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
-def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
-def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
-def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
-def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
-def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
-def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
-def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
-def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
-def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
-def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
-def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
-def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
-def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf16", BF16, F32>;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.f16", F16, F32>;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.16x16x128.i8", I32, I32>;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.bf8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.bf8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.fp8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.fp8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf16", BF16, F32>;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.f16", F16, F32>;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.32x32x64.i8", I32, I32>;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.bf8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.bf8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.fp8.bf8", I32, F32>;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.fp8.fp8", I32, F32>;
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 47eb2944f18f7..f7aa1cdf5bb9f 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -309,148 +309,120 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg11 : vector<8 x bf16>,
%arg12 : vector<16 x bf16>,
%arg13 : vector<8 x i32>) -> vector<4 x f32> {
- %csti32 = llvm.mlir.constant(42 : i32) : i32
+ %index = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
- // CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<8xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<8xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi16>, vector<8xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi16>, vector<8xi16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
- (vector<8xf16>, vector<16xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
- (vector<8xf16>, vector<16xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
- (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
- (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32) -> vector<4xf32>
+ %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %index, 0, 0 :
+ (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32) -> vector<16xf32>
+ %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %index, 0, 0 :
+ (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32) -> vector<4xf32>
+ %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %index, 0, 0 :
+ (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32) -> vector<16xf32>
+ %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %index, 0, 0 :
+ (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32) -> vector<4xi32>
+ %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32) -> vector<4xi32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32) -> vector<16xi32>
+ %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32) -> vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32) -> vector<4xf32>
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %index, 0, 0 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32) -> vector<16xf32>
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %index, 0, 0 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32) -> vector<4xf32>
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %index, 0, 0 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32) -> vector<16xf32>
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %index, 0, 0 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32) -> vector<4xi32>
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32) -> vector<4xi32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32) -> vector<16xi32>
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32) -> vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
llvm.return %r0 : vector<4 x f32>
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 1e7ff3aaba622..d9d766bc01a17 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -565,149 +565,121 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg11 : vector<8 x bf16>,
%arg12 : vector<16 x bf16>,
%arg13 : vector<8 x i32>) -> vector<4 x f32> {
- %csti32 = llvm.mlir.constant(42 : i32) : i32
+ %index = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<8xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xf16>, vector<8xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi16>, vector<8xi16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi16>, vector<8xi16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
- %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
- %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
- (vector<2xi32>, vector<4xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
- (vector<8xf16>, vector<16xf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
- (vector<8xf16>, vector<16xf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
- (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
- (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
- %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xi32>,
- i32, i32, i32) -> vector<4xi32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>,
- i32, i32, i32) -> vector<4xf32>
-
- // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
- %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xi32>,
- i32, i32, i32) -> vector<16xi32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
-
- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>,
- i32, i32, i32) -> vector<16xf32>
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %index, 0, 0 :
+ (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %index, 0, 0 :
+ (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %index, 0, 0 :
+ (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %index, 0, 0 :
+ (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 0, i32 0)
+ %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32) -> vector<4xi32>
+
+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 0, i32 0)
+ %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32) -> vector<16xi32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %index, 0, 0 :
+ (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %index, 0, 0 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %index, 0, 0 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %index, 0, 0 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %index, 0, 0 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 0, i32 0)
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32) -> vector<4xi32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 0, i32 0)
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32) -> vector<16xi32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 0, i32 0)
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %index, 0, 0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
llvm.return %r0 : vector<4 x f32>
}
>From 15f55b4f8493a814453cc9578ba5de22cf7343f4 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Tue, 13 Jan 2026 17:29:22 +0000
Subject: [PATCH 4/6] [ROCDL] Added type and structural constraints to MFMA_scale
Ops
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 39 ++-
mlir/test/Dialect/LLVMIR/rocdl.mlir | 310 +++++++++----------
mlir/test/Target/LLVMIR/rocdl.mlir | 210 ++++++-------
3 files changed, 274 insertions(+), 285 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 9ad7672e59cdb..14a56d2ed97e0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -625,12 +625,30 @@ class ROCDL_Mfma_IntrOp<string mnemonic, Type AB, Type CD> :
}];
}
+// Scaled MFMA intrinsics (mfma.scale.*).
+// - Operands $a and $b are overloaded (intrinsic operand positions 0 and 1),
+//   so packed vectors of different lengths can be passed for A and B.
+// - cbsz, blgp, opselA and opselB are I32 attributes mapped to intrinsic
+//   argument positions 3, 4, 5 and 7; scaleA and scaleB stay SSA i32 operands.
+// - The result is constrained exactly like the accumulator $c (a vector of CD).
+class ROCDL_Mfma_Scale_IntrOp<string mnemonic, Type AB, Type CD> :
+ ROCDL_IntrOp<mnemonic, [], [0, 1], [], 1, 0, 0, 0, [3, 4, 5, 7], ["cbsz", "blgp", "opselA", "opselB"]>,
+ Arguments<(ins
+ LLVM_VectorOf<AB>:$a,
+ LLVM_VectorOf<AB>:$b,
+ LLVM_VectorOf<CD>:$c,
+ I32Attr:$cbsz,
+ I32Attr:$blgp,
+ I32Attr:$opselA,
+ I32:$scaleA,
+ I32Attr:$opselB,
+ I32:$scaleB)> {
+ // Match the accumulator constraint: scalar results are not produced by these ops.
+ let results = (outs LLVM_VectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $a `,` $b `,` $c `,` $cbsz `,` $blgp `,` $opselA `,` $scaleA `,` $opselB `,` $scaleB attr-dict `:` functional-type(operands, $res)
+ }];
+}
+
class ROCDL_Smfmac_IntrOp<string mnemonic, Type AB, Type CD> :
ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [4, 5], ["cbsz", "abid"]>,
Arguments<(ins
- LLVM_ScalarOrVectorOf<AB>:$a,
- LLVM_ScalarOrVectorOf<AB>:$b,
- LLVM_ScalarOrVectorOf<CD>:$c,
+ LLVM_VectorOf<AB>:$a,
+ LLVM_VectorOf<AB>:$b,
+ LLVM_VectorOf<CD>:$c,
I32:$index,
I32Attr:$cbsz,
I32Attr:$abid)> {
@@ -640,16 +658,6 @@ class ROCDL_Smfmac_IntrOp<string mnemonic, Type AB, Type CD> :
}];
}
-//===---------------------------------------------------------------------===//
-// MFMA intrinsics with overloaded operands
-class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
- list<Trait> traits = []> :
- ROCDL_IntrOp<mnemonic, [], overloadedOperands, traits, 1>,
- Arguments<(ins Variadic<LLVM_Type>:$args)> {
- let assemblyFormat =
- "$args attr-dict `:` functional-type($args, $res)";
-}
-
// Available on all CDNA.
def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32", /*Type AB=*/F32, /*Type CD=*/F32>;
def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32", F32, F32>;
@@ -701,8 +709,9 @@ def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16", F16
def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16", BF16, F32>;
def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8", I32, I32>;
def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16", F16, F32>;
-def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
-def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;
+
+// Scaled MFMA ops: A/B are packed into i32 vectors whose element format is
+// selected by the cbsz/blgp attributes; the accumulator is an f32 vector.
+def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_Scale_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", I32, F32>;
+def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_Scale_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", I32, F32>;
// 2:4 Sparsity ops (GFX94x)
def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x32.f16", F16, F32>;
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index f7aa1cdf5bb9f..dd4e66d50c4a7 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -430,137 +430,132 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
 %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
 %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {
- %cst0 = llvm.mlir.constant(0 : i32) : i32
- %cst1 = llvm.mlir.constant(1 : i32) : i32
- %cst2 = llvm.mlir.constant(2 : i32) : i32
- %cst3 = llvm.mlir.constant(3 : i32) : i32
- %cst4 = llvm.mlir.constant(4 : i32) : i32
 // CHECK-LABEL: rocdl.mfma.scale.f32.32x32x64.f8f6f4
+ // Operand order: a, b, c, cbsz, blgp, opselA, scaleA, opselB, scaleB.
+ // cbsz/blgp/opselA/opselB are now I32 attributes written as bare integers;
+ // only scaleA and scaleB remain SSA values, so the functional type lists five
+ // operands. cbsz (A) and blgp (B) values 0..4 select fp8, bf8, fp6, bf6, fp4.
 // fp8 * fp8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r00 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r00 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 0, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp8 * bf8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r01 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r01 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 0, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp8 * fp6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r02 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r02 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 0, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp8 * bf6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r03 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r03 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 0, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp8 * fp4
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r04 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r04 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, 0, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf8 * fp8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r10 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r10 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 1, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf8 * bf8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r11 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r11 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 1, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf8 * fp6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r12 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r12 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 1, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf8 * bf6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r13 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r13 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 1, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf8 * fp4
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r14 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r14 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, 1, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp6 * fp8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r20 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r20 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 2, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp6 * bf8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r21 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r21 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 2, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp6 * fp6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r22 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r22 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 2, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp6 * bf6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r23 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r23 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 2, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp6 * fp4
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r24 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r24 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, 2, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf6 * fp8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r30 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r30 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 3, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf6 * bf8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r31 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r31 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 3, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf6 * fp6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r32 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r32 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 3, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf6 * bf6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r33 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r33 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 3, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // bf6 * fp4
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r34 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r34 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, 3, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp4 * fp8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r40 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r40 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, 4, 0, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp4 * bf8
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r41 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r41 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, 4, 1, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp4 * fp6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r42 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r42 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, 4, 2, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp4 * bf6
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r43 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r43 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, 4, 3, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 // fp4 * fp4
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- %r44 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+ %r44 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg4, %arg1, 4, 4, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
 llvm.return
 }
@@ -568,137 +563,132 @@ llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
%arg1 : vector<4 x f32>, %arg2 : vector<8xi32>,
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {
- %cst0 = llvm.mlir.constant(0 : i32) : i32
- %cst1 = llvm.mlir.constant(1 : i32) : i32
- %cst2 = llvm.mlir.constant(2 : i32) : i32
- %cst3 = llvm.mlir.constant(3 : i32) : i32
- %cst4 = llvm.mlir.constant(4 : i32) : i32
// CHECK-LABEL: rocdl.mfma.scale.f32.16x16x128.f8f6f4
// fp8 * fp8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 0, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * bf8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 0, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * fp6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 0, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * bf6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 0, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * fp4
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, 0, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * fp8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 1, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * bf8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 1, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * fp6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 1, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * bf6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 1, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * fp4
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, 1, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * fp8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 2, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * bf8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 2, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * fp6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 2, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * bf6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r23 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r23 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 2, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * fp4
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r24 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r24 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, 2, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * fp8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r30 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r30 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 3, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * bf8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r31 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r31 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 3, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * fp6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r32 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r32 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 3, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * bf6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r33 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r33 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 3, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * fp4
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r34 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r34 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, 3, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * fp8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r40 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r40 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, 4, 0, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * bf8
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r41 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r41 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, 4, 1, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * fp6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r42 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r42 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, 4, 2, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * bf6
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r43 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r43 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, 4, 3, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * fp4
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- %r44 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}} : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %r44 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg4, %arg1, 4, 4, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index d9d766bc01a17..73daa5fd854a1 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -688,137 +688,132 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
%arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> {
- %cst0 = llvm.mlir.constant(0 : i32) : i32
- %cst1 = llvm.mlir.constant(1 : i32) : i32
- %cst2 = llvm.mlir.constant(2 : i32) : i32
- %cst3 = llvm.mlir.constant(3 : i32) : i32
- %cst4 = llvm.mlir.constant(4 : i32) : i32
// CHECK-LABEL: rocdl.mfma.scale.f32.32x32x64.f8f6f4
// fp8 * fp8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r00 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r00 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 0, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp8 * bf8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r01 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r01 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 0, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp8 * fp6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r02 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r02 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 0, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp8 * bf6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r03 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r03 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 0, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp8 * fp4
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r04 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r04 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, 0, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf8 * fp8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r10 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r10 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 1, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf8 * bf8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r11 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r11 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, 1, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf8 * fp6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r12 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r12 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 1, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf8 * bf6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r13 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r13 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, 1, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf8 * fp4
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r14 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r14 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, 1, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp6 * fp8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r20 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r20 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 2, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp6 * bf8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r21 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r21 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 2, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp6 * fp6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r22 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r22 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 2, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp6 * bf6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r23 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r23 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 2, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp6 * fp4
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r24 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r24 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, 2, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf6 * fp8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r30 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r30 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 3, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf6 * bf8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r31 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r31 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, 3, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf6 * fp6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r32 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r32 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 3, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf6 * bf6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r33 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r33 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, 3, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// bf6 * fp4
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r34 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r34 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, 3, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp4 * fp8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r40 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r40 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, 4, 0, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp4 * bf8
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r41 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r41 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, 4, 1, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp4 * fp6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r42 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r42 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, 4, 2, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp4 * bf6
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r43 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r43 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, 4, 3, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
// fp4 * fp4
// CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r44 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ %r44 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg4, %arg1, 4, 4, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
llvm.return %r00 : vector<16 x f32>
}
@@ -826,137 +821,132 @@ llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
%arg1 : vector<4 x f32>, %arg2 : vector<8xi32>,
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<4 x f32> {
- %cst0 = llvm.mlir.constant(0 : i32) : i32
- %cst1 = llvm.mlir.constant(1 : i32) : i32
- %cst2 = llvm.mlir.constant(2 : i32) : i32
- %cst3 = llvm.mlir.constant(3 : i32) : i32
- %cst4 = llvm.mlir.constant(4 : i32) : i32
// CHECK-LABEL: rocdl.mfma.scale.f32.16x16x128.f8f6f4
// fp8 * fp8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 0, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * bf8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 0, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * fp6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 0, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * bf6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 0, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp8 * fp4
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, 0, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * fp8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 1, 0, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * bf8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, 1, 1, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * fp6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 1, 2, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * bf6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, 1, 3, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf8 * fp4
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, 1, 4, 0, %arg0, 0, %arg0 :
+ (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * fp8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 2, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * bf8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 2, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * fp6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 2, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * bf6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r23 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r23 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 2, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp6 * fp4
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r24 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r24 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, 2, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * fp8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r30 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r30 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 3, 0, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * bf8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r31 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r31 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, 3, 1, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * fp6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r32 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r32 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 3, 2, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * bf6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r33 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r33 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, 3, 3, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// bf6 * fp4
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r34 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r34 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg4, %arg1, 3, 4, 0, %arg0, 0, %arg0 :
+ (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * fp8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r40 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r40 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, 4, 0, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * bf8
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r41 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r41 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg2, %arg1, 4, 1, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * fp6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r42 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r42 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, 4, 2, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * bf6
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
- %r43 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r43 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg3, %arg1, 4, 3, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
// fp4 * fp4
// CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}}
- %r44 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
- (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ %r44 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4, %arg4, %arg1, 4, 4, 0, %arg0, 0, %arg0 :
+ (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
llvm.return %r00 : vector<4 x f32>
}
>From 249a9b9955ac0c0520c94b6a9282a4e887a254fd Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Wed, 14 Jan 2026 11:06:48 +0000
Subject: [PATCH 5/6] [ROCDL] used concrete vector types for all mfma ops
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 164 +++++++++----------
1 file changed, 82 insertions(+), 82 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 14a56d2ed97e0..2ae30eb42ac84 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -610,16 +610,16 @@ def ROCDL_IglpOpt : ROCDL_ConcreteNonMemIntrOp<"iglp.opt", [], 0, [0], ["variant
//===---------------------------------------------------------------------===//
// Xdlops intrinsics
-class ROCDL_Mfma_IntrOp<string mnemonic, Type AB, Type CD> :
+class ROCDL_Mfma_IntrOp<string mnemonic, Type ABType, Type CDType> :
ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [3, 4, 5], ["cbsz", "abid", "blgp"]>,
Arguments<(ins
- LLVM_ScalarOrVectorOf<AB>:$a,
- LLVM_ScalarOrVectorOf<AB>:$b,
- LLVM_ScalarOrVectorOf<CD>:$c,
+ ABType:$a,
+ ABType:$b,
+ CDType:$c,
I32Attr:$cbsz,
I32Attr:$abid,
I32Attr:$blgp)> {
- let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let results = (outs CDType:$res);
let assemblyFormat = [{
$a `,` $b `,` $c `,` $cbsz `,` $abid `,` $blgp attr-dict `:` functional-type(operands, $res)
}];
@@ -643,106 +643,106 @@ class ROCDL_Mfma_Scale_IntrOp<string mnemonic, Type AB, Type CD> :
}];
}
-class ROCDL_Smfmac_IntrOp<string mnemonic, Type AB, Type CD> :
+class ROCDL_Smfmac_IntrOp<string mnemonic, Type AType, Type BType, Type CDType> :
ROCDL_IntrOp<mnemonic, [], [], [], 1, 0, 0, 0, [4, 5], ["cbsz", "abid"]>,
Arguments<(ins
- LLVM_VectorOf<AB>:$a,
- LLVM_VectorOf<AB>:$b,
- LLVM_VectorOf<CD>:$c,
+ AType:$a,
+ BType:$b,
+ CDType:$c,
I32:$index,
I32Attr:$cbsz,
I32Attr:$abid)> {
- let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let results = (outs CDType:$res);
let assemblyFormat = [{
$a `,` $b `,` $c `,` $index `,` $cbsz `,` $abid attr-dict `:` functional-type(operands, $res)
}];
}
// Available on all CDNA.
-def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32", /*Type AB=*/F32, /*Type CD=*/F32>;
-def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32", F32, F32>;
-def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32", F32, F32>;
-def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32", F32, F32>;
-def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32", F32, F32>;
-def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16", F16, F32>;
-def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16", F16, F32>;
-def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16", F16, F32>;
-def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16", F16, F32>;
-def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16", F16, F32>;
-def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8", I32, I32>;
-def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8", I32, I32>;
-def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8", I32, I32>;
-def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8", I32, I32>;
-def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8", I32, I32>;
-def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16", I16, F32>;
-def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16", I16, F32>;
-def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16", I16, F32>;
-def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16", I16, F32>;
-def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16", I16, F32>;
+def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32", /*ABType=*/F32, /*CDType=*/ROCDL_ConcreteVector<F32, 32>>;
+def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32", F32, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32", F32, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32", F32, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32", F32, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16", ROCDL_ConcreteVector<F16, 4>, ROCDL_ConcreteVector<F32, 32>>;
+def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16", ROCDL_ConcreteVector<F16, 4>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16", ROCDL_ConcreteVector<F16, 4>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16", ROCDL_ConcreteVector<F16, 4>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16", ROCDL_ConcreteVector<F16, 4>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8", I32, ROCDL_ConcreteVector<I32, 32>>;
+def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8", I32, ROCDL_ConcreteVector<I32, 16>>;
+def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8", I32, ROCDL_ConcreteVector<I32, 4>>;
+def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8", I32, ROCDL_ConcreteVector<I32, 16>>;
+def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8", I32, ROCDL_ConcreteVector<I32, 4>>;
+def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16", ROCDL_ConcreteVector<I16, 2>, ROCDL_ConcreteVector<F32, 32>>;
+def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16", ROCDL_ConcreteVector<I16, 2>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16", ROCDL_ConcreteVector<I16, 2>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16", ROCDL_ConcreteVector<I16, 2>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16", ROCDL_ConcreteVector<I16, 2>, ROCDL_ConcreteVector<F32, 4>>;
// New in gfx90a.
-def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8bf16.1k", I16, F32>;
-def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16bf16.1k", I16, F32>;
+def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16.1k", ROCDL_ConcreteVector<I16, 4>, ROCDL_ConcreteVector<F32, 32>>;
+def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4bf16.1k", ROCDL_ConcreteVector<I16, 4>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4bf16.1k", ROCDL_ConcreteVector<I16, 4>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8bf16.1k", ROCDL_ConcreteVector<I16, 4>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16bf16.1k", ROCDL_ConcreteVector<I16, 4>, ROCDL_ConcreteVector<F32, 4>>;
// Note: in gfx94x, unlike in gfx90a, the f64 xdlops use the "blgp" argument as
// a NEG bitfield. See IntrinsicsAMDGPU.td for more info.
-def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64", F64, F64>;
+def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64", F64, ROCDL_ConcreteVector<F64, 4>>;
def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.4x4x4f64", F64, F64>;
// New in gfx94x.
-def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8", I64, I32>;
-def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8", I64, I32>;
-def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32", F32, F32>;
-def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32", F32, F32>;
-def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8", I64, F32>;
-def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8", I64, F32>;
-def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8", I64, F32>;
-def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8", I64, F32>;
-def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8", I64, F32>;
+def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8", I64, ROCDL_ConcreteVector<I32, 4>>;
+def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8", I64, ROCDL_ConcreteVector<I32, 16>>;
+def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32", ROCDL_ConcreteVector<F32, 2>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32", ROCDL_ConcreteVector<F32, 2>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8", I64, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8", I64, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8", I64, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8", I64, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8", I64, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8", I64, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8", I64, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8", I64, ROCDL_ConcreteVector<F32, 16>>;
// New in gfx950.
-def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16", BF16, F32>;
-def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8", I32, I32>;
-def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16", F16, F32>;
-def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16", BF16, F32>;
-def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8", I32, I32>;
-def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16", F16, F32>;
+def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16", ROCDL_ConcreteVector<BF16, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 4>>;
+def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16", ROCDL_ConcreteVector<F16, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16", ROCDL_ConcreteVector<BF16, 8>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 16>>;
+def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16", ROCDL_ConcreteVector<F16, 8>, ROCDL_ConcreteVector<F32, 16>>;
def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_Scale_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", I32, F32>;
def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_Scale_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", I32, F32>;
// 2:4 Sparsity ops (GFX94x)
-def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x32.f16", F16, F32>;
-def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x16.f16", F16, F32>;
-def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x32.bf16", I16, F32>;
-def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x16.bf16", I16, F32>;
-def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.16x16x64.i8", I32, I32>;
-def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.32x32x32.i8", I32, I32>;
-def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf8.fp8", I32, F32>;
-def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.fp8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.fp8.fp8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf8.fp8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.fp8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.fp8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x32.f16", ROCDL_ConcreteVector<F16, 4>, ROCDL_ConcreteVector<F16, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x16.f16", ROCDL_ConcreteVector<F16, 4>, ROCDL_ConcreteVector<F16, 8>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x32.bf16", ROCDL_ConcreteVector<I16, 4>, ROCDL_ConcreteVector<I16, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x16.bf16", ROCDL_ConcreteVector<I16, 4>, ROCDL_ConcreteVector<I16, 8>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.16x16x64.i8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 4>>;
+def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.32x32x32.i8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 16>>;
+def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf8.bf8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf8.fp8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.fp8.bf8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.fp8.fp8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf8.bf8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf8.fp8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.fp8.bf8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.fp8.fp8", ROCDL_ConcreteVector<I32, 2>, ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<F32, 16>>;
// New in gfx950.
-def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf16", BF16, F32>;
-def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.f16", F16, F32>;
-def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.16x16x128.i8", I32, I32>;
-def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.bf8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.bf8.fp8", I32, F32>;
-def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.fp8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.fp8.fp8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf16", BF16, F32>;
-def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.f16", F16, F32>;
-def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.32x32x64.i8", I32, I32>;
-def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.bf8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.bf8.fp8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.fp8.bf8", I32, F32>;
-def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.fp8.fp8", I32, F32>;
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.bf16", ROCDL_ConcreteVector<BF16, 8>, ROCDL_ConcreteVector<BF16, 16>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x64.f16", ROCDL_ConcreteVector<F16, 8>, ROCDL_ConcreteVector<F16, 16>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.16x16x128.i8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<I32, 4>>;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.bf8.bf8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.bf8.fp8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.fp8.bf8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.16x16x128.fp8.fp8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 4>>;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.bf16", ROCDL_ConcreteVector<BF16, 8>, ROCDL_ConcreteVector<BF16, 16>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x32.f16", ROCDL_ConcreteVector<F16, 8>, ROCDL_ConcreteVector<F16, 16>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Smfmac_IntrOp<"smfmac.i32.32x32x64.i8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<I32, 16>>;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.bf8.bf8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.bf8.fp8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.fp8.bf8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 16>>;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Smfmac_IntrOp<"smfmac.f32.32x32x64.fp8.fp8", ROCDL_ConcreteVector<I32, 4>, ROCDL_ConcreteVector<I32, 8>, ROCDL_ConcreteVector<F32, 16>>;
//===---------------------------------------------------------------------===//
>From ce5f02c826373013546e66698d95030e51ab1b8c Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Wed, 14 Jan 2026 14:07:44 +0000
Subject: [PATCH 6/6] [ROCDL] Fixed lowering from amdgpu to rocdl regarding
mfma ops
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 35 ++++-----
.../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 68 ++++++++--------
mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir | 78 +++++++++----------
.../AMDGPUToROCDL/sparse-mfma-gfx950.mlir | 28 +++----
.../Conversion/AMDGPUToROCDL/sparse-mfma.mlir | 28 +++----
5 files changed, 115 insertions(+), 122 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 6427807e944a1..3f24f814a2143 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1402,14 +1402,15 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
- loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
- createI32Constant(rewriter, loc, bTypeCode),
- /*scale A byte=*/zero, /*scale A=*/zero,
- /*scale B byte=*/zero, /*scale B=*/zero});
+ loweredOp.addOperands({ /*scale A=*/zero, /*scale B=*/zero});
+ loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
+ {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
+ {"opselA", rewriter.getI32IntegerAttr(0)},
+ {"opselB", rewriter.getI32IntegerAttr(0)}});
} else {
- loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
- createI32Constant(rewriter, loc, op.getAbid()),
- createI32Constant(rewriter, loc, getBlgpField)});
+ loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
+ {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
+ {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
};
Value lowered = rewriter.create(loweredOp)->getResult(0);
if (outType != intrinsicOutType)
@@ -1446,19 +1447,17 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
{packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC()});
- Value scalesIdxA =
- createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
- Value scalesIdxB =
- createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
loweredOp.addOperands(
- {createI32Constant(rewriter, loc, aTypeCode),
- createI32Constant(rewriter, loc, bTypeCode),
- /*scales idx A=*/scalesIdxA,
+ {
/*scales A*/
castScaleOperand(rewriter, loc, adaptor.getScalesA()),
- /*scales idx B=*/scalesIdxB,
/*scales B*/
castScaleOperand(rewriter, loc, adaptor.getScalesB())});
+ loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
+ {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
+ {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
+ {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
+
Value lowered = rewriter.create(loweredOp)->getResult(0);
rewriter.replaceOp(op, lowered);
return success();
@@ -1502,9 +1501,9 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
OperationState loweredOp(loc, maybeIntrinsic.value());
loweredOp.addTypes(outType);
- loweredOp.addOperands({a, b, c, sparseIdx,
- createI32Constant(rewriter, loc, op.getCbsz()),
- createI32Constant(rewriter, loc, op.getAbid())});
+ loweredOp.addOperands({a, b, c, sparseIdx});
+ loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
+ {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
Value lowered = rewriter.create(loweredOp)->getResult(0);
rewriter.replaceOp(op, lowered);
return success();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index c746d7690b00d..39d90ceb14436 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -5,53 +5,51 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
%arg6 : vector<4xi32>, %arg7 : vector<32xf8E4M3FN>,
%arg8 : vector<32xf8E5M2>, %arg9 : vector<32xf6E2M3FN>,
%arg10 : vector<32xf6E3M2FN>, %arg11 : vector<32xf4E2M1FN>) {
- // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x16 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x32 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x16 %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x32 %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
- // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>) -> vector<16xi32>
amdgpu.mfma 32x32x32 %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
- // CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ // CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
amdgpu.mfma 16x16x64 %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 0, 0, 0, %[[c0]], 0, %[[c0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.mfma 32x32x64 %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 0, 0, 0, %[[c0]], 0, %[[c0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.mfma 16x16x128 %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
- // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 1, 1, 0, %[[c0]], 0, %[[c0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.mfma 32x32x64 %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 1, 1, 0, %[[c0]], 0, %[[c0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.mfma 16x16x128 %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
- // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 2, 2, 0, %[[c0]], 0, %[[c0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.mfma 32x32x64 %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 2, 2, 0, %[[c0]], 0, %[[c0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.mfma 16x16x128 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
- // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 3, 3, 0, %[[c0]], 0, %[[c0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.mfma 32x32x64 %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 3, 3, 0, %[[c0]], 0, %[[c0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.mfma 16x16x128 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
- // CHECK-DAG: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 4, 4, 0, %[[c0]], 0, %[[c0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.mfma 32x32x64 %arg11 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 4, 4, 0, %[[c0]], 0, %[[c0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.mfma 16x16x128 %arg11 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 2, 4, 0, %[[c0]], 0, %[[c0]] : (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.mfma 32x32x64 %arg9 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 2, 4, 0, %[[c0]], 0, %[[c0]] : (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.mfma 16x16x128 %arg9 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<4xf32>
func.return
}
+
// CHECK-LABEL: func @scaled_mfma_to_rocdl(
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
@@ -60,44 +58,40 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
%arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
- // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
// CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 0, 0, 0, %[[b0]], 1, %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 0, 0, 0, %[[b0]], 1, %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
// CHECK: llvm.bitcast
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 1, 1, 0, %[[b0]], 1, %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 1, 1, 0, %[[b0]], 1, %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
// CHECK: llvm.bitcast
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 2, 2, 0, %[[b0]], 1, %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 2, 2, 0, %[[b0]], 1, %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
// CHECK: llvm.bitcast
- // CHECK: llvm.mlir.constant(3 : i32) : i32
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 3, 3, 0, %[[b0]], 1, %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 3, 3, 0, %[[b0]], 1, %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
// CHECK: llvm.bitcast
- // CHECK: llvm.mlir.constant(4 : i32) : i32
- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 4, 4, 0, %[[b0]], 1, %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, {{.*}}, 4, 4, 0, %[[b0]], 1, %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
func.return
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
index e292d98183cd5..8d2f38794de5b 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -8,89 +8,89 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
%arg12 : vector<4xf64>, %arg13 : vector<8xi8>,
%arg14 : vector<2xf32>, %arg15 : vector<8xf8E5M2FNUZ>,
%arg16 : vector<8xf8E4M3FNUZ>) {
- // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>) -> vector<32xf32>
amdgpu.mfma 32x32x1 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : f32, f32, vector<32xf32>
- // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 16x16x1 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : f32, f32, vector<16xf32>
- // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 4x4x1 %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : f32, f32, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x2 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x4 %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>) -> vector<32xf32>
amdgpu.mfma 32x32x4 %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<32xf32>
- // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 16x16x4 %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 4x4x4 %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x8 %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x16 %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: %[[BITCAST_4xi8_i32:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
- // CHECK: rocdl.mfma.i32.32x32x4i8 %[[BITCAST_4xi8_i32]], %[[BITCAST_4xi8_i32]], {{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
+ // CHECK: rocdl.mfma.i32.32x32x4i8 %[[BITCAST_4xi8_i32]], %[[BITCAST_4xi8_i32]], {{.*}}: (i32, i32, vector<32xi32>) -> vector<32xi32>
amdgpu.mfma 32x32x4 %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
- // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>) -> vector<16xi32>
amdgpu.mfma 16x16x4 %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
- // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
amdgpu.mfma 4x4x4 %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
- // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>) -> vector<16xi32>
amdgpu.mfma 32x32x8 %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
- // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
amdgpu.mfma 16x16x16 %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
// CHECK: %[[BITCAST_2xbf16_2xi16:.+]] = llvm.bitcast {{.*}} : vector<2xbf16> to vector<2xi16>
- // CHECK: rocdl.mfma.f32.32x32x2bf16 %[[BITCAST_2xbf16_2xi16]], %[[BITCAST_2xbf16_2xi16]], %{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ // CHECK: rocdl.mfma.f32.32x32x2bf16 %[[BITCAST_2xbf16_2xi16]], %[[BITCAST_2xbf16_2xi16]], %{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>) -> vector<32xf32>
amdgpu.mfma 32x32x2 %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
- // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 16x16x2 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 4x4x2 %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x4 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x8 %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
// CHECK: %[[BITCAST_4xbf16_4xi16:.+]] = llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
- // CHECK: rocdl.mfma.f32.32x32x4bf16.1k %[[BITCAST_4xbf16_4xi16]], %[[BITCAST_4xbf16_4xi16]], {{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4bf16.1k %[[BITCAST_4xbf16_4xi16]], %[[BITCAST_4xbf16_4xi16]], {{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>) -> vector<32xf32>
amdgpu.mfma 32x32x4 %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
- // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 16x16x4 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 4x4x4 %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x8 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x16 %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
- // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
+ // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>) -> vector<4xf64>
amdgpu.mfma 16x16x4 %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f64, f64, vector<4xf64>
- // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
+ // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64) -> f64
amdgpu.mfma 4x4x4 %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : f64, f64, f64
// CHECK: %[[BITCAST_8xi8_i64:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
- // CHECK: rocdl.mfma.i32.16x16x32.i8 %[[BITCAST_8xi8_i64]], %[[BITCAST_8xi8_i64]], {{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ // CHECK: rocdl.mfma.i32.16x16x32.i8 %[[BITCAST_8xi8_i64]], %[[BITCAST_8xi8_i64]], {{.*}}: (i64, i64, vector<4xi32>) -> vector<4xi32>
amdgpu.mfma 16x16x32 %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
- // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>) -> vector<16xi32>
amdgpu.mfma 32x32x16 %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
- // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x8 %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x4 %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
// CHECK: %[[BITCAST_8xi8_i64_1:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
- // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_1]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_1]], {{.*}}: (i64, i64, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x32 %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
// CHECK: %[[BITCAST_8xi8_i64_2:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
- // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_2]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_2]], {{.*}}: (i64, i64, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x32 %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x32 %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.16x16x32.fp8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.mfma.f32.16x16x32.fp8.fp8{{.*}}: (i64, i64, vector<4xf32>) -> vector<4xf32>
amdgpu.mfma 16x16x32 %arg16 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
- // CHECK: rocdl.mfma.f32.32x32x16.bf8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x16.bf8.bf8{{.*}}: (i64, i64, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x16 %arg15 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.32x32x16.bf8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x16.bf8.fp8{{.*}}: (i64, i64, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x16 %arg15 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.32x32x16.fp8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x16.fp8.bf8{{.*}}: (i64, i64, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x16 %arg16 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
- // CHECK: rocdl.mfma.f32.32x32x16.fp8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.mfma.f32.32x32x16.fp8.fp8{{.*}}: (i64, i64, vector<16xf32>) -> vector<16xf32>
amdgpu.mfma 32x32x16 %arg16 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
func.return
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
index 266e0e7e15595..557c71c32ee54 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
@@ -8,53 +8,53 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>,
%arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>,
%arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
// CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
- // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
- // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
// CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
// CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
- // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32) -> vector<4xi32>
amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
// CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
- // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
// CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
- // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
- // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
- // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
- // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
- // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32) -> vector<16xi32>
amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
- // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
func.return
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
index b2c91c3d9bed1..1a6fef6cf72da 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -8,55 +8,55 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
%arg12 : vector<16xf8E4M3FNUZ>, %arg13 : vector<16xf8E5M2FNUZ>,
%arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
// CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
- // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
// CHECK: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
// CHECK: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
- // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
- // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
// CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
// CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
- // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32) -> vector<4xi32>
amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
// CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
// CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
- // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
// CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
// CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
- // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
- // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
- // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
- // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32) -> vector<16xi32>
amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
- // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
- // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32) -> vector<16xf32>
amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
func.return
More information about the Mlir-commits
mailing list