[Mlir-commits] [mlir] [mlir][rocdl] add gfx950 smfmac instructions to rocdl dialect (PR #171737)
Eric Feng
llvmlistbot at llvm.org
Wed Dec 10 15:50:11 PST 2025
https://github.com/efric updated https://github.com/llvm/llvm-project/pull/171737
>From e50fb1b99cfe5e1024dee010a553826d94a5d223 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 10 Dec 2025 14:33:51 -0800
Subject: [PATCH 1/2] add gfx950 smfmac to rocdl
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 15 ++++
mlir/test/Dialect/LLVMIR/rocdl.mlir | 77 ++++++++++++++++++-
mlir/test/Target/LLVMIR/rocdl.mlir | 78 +++++++++++++++++++-
3 files changed, 167 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0edb208a8fcba..fe8a854cd1321 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -592,6 +592,21 @@ def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.b
def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+// New in gfx950.
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 1b50feea418b6..745fea8e38955 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -288,7 +288,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
- %arg9 : vector<16xi32>) -> vector<4 x f32> {
+ %arg9 : vector<16xi32>,
+ %arg10 : vector<8 x f16>,
+ %arg11 : vector<16 x f16>,
+ %arg12 : vector<8 x bf16>,
+ %arg13 : vector<16 x bf16>,
+ %arg14 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -362,6 +367,76 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+ i32, i32, i32) -> vector<4xi32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+ i32, i32, i32) -> vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
llvm.return %r0 : vector<4 x f32>
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7be6d6ba4d7be..868597fba92a6 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -528,7 +528,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
- %arg9 : vector<16xi32>) -> vector<4 x f32> {
+ %arg9 : vector<16xi32>,
+ %arg10 : vector<8 x f16>,
+ %arg11 : vector<16 x f16>,
+ %arg12 : vector<8 x bf16>,
+ %arg13 : vector<16 x bf16>,
+ %arg14 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -598,12 +603,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
-
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+ i32, i32, i32) -> vector<4xi32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+ i32, i32, i32) -> vector<16xi32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
llvm.return %r0 : vector<4 x f32>
}
>From c6638188980c9cee15b97116ebd908fb866c7a41 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 10 Dec 2025 15:49:58 -0800
Subject: [PATCH 2/2] nit: reuse existing %arg2 (vector<8xf16>) in smfmac tests and renumber the new test arguments
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/test/Dialect/LLVMIR/rocdl.mlir | 37 ++++++++++++++---------------
mlir/test/Target/LLVMIR/rocdl.mlir | 37 ++++++++++++++---------------
2 files changed, 36 insertions(+), 38 deletions(-)
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 745fea8e38955..c179b110e32e7 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -289,11 +289,10 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
%arg9 : vector<16xi32>,
- %arg10 : vector<8 x f16>,
- %arg11 : vector<16 x f16>,
- %arg12 : vector<8 x bf16>,
- %arg13 : vector<16 x bf16>,
- %arg14 : vector<8 x i32>) -> vector<4 x f32> {
+ %arg10 : vector<16 x f16>,
+ %arg11 : vector<8 x bf16>,
+ %arg12 : vector<16 x bf16>,
+ %arg13 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -368,72 +367,72 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
i32, i32, i32) -> vector<16xf32>
// CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
(vector<8xf16>, vector<16xf16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
(vector<8xf16>, vector<16xf16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
(vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
(vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xi32>,
i32, i32, i32) -> vector<4xi32>
// CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xi32>,
i32, i32, i32) -> vector<16xi32>
// CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 868597fba92a6..a7f16a5054626 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -529,11 +529,10 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
%arg9 : vector<16xi32>,
- %arg10 : vector<8 x f16>,
- %arg11 : vector<16 x f16>,
- %arg12 : vector<8 x bf16>,
- %arg13 : vector<16 x bf16>,
- %arg14 : vector<8 x i32>) -> vector<4 x f32> {
+ %arg10 : vector<16 x f16>,
+ %arg11 : vector<8 x bf16>,
+ %arg12 : vector<16 x bf16>,
+ %arg13 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -609,72 +608,72 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
i32, i32, i32) -> vector<16xf32>
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
(vector<8xf16>, vector<16xf16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
(vector<8xf16>, vector<16xf16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
(vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
(vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
- %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xi32>,
i32, i32, i32) -> vector<4xi32>
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>
// CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
- %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xi32>,
i32, i32, i32) -> vector<16xi32>
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
- %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi32>, vector<8xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
More information about the Mlir-commits mailing list