[Mlir-commits] [mlir] [mlir][rocdl] add gfx950 smfmac instructions to rocdl dialect (PR #171737)

Eric Feng llvmlistbot at llvm.org
Wed Dec 10 15:50:11 PST 2025


https://github.com/efric updated https://github.com/llvm/llvm-project/pull/171737

>From e50fb1b99cfe5e1024dee010a553826d94a5d223 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 10 Dec 2025 14:33:51 -0800
Subject: [PATCH 1/2] add gfx950 smfmac to rocdl

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 15 ++++
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 77 ++++++++++++++++++-
 mlir/test/Target/LLVMIR/rocdl.mlir           | 78 +++++++++++++++++++-
 3 files changed, 167 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0edb208a8fcba..fe8a854cd1321 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -592,6 +592,21 @@ def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.b
 def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
 def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
 def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+// New in gfx950.
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
 
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 1b50feea418b6..745fea8e38955 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -288,7 +288,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg6 : vector<8 x i16>,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
-                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+                   %arg9 : vector<16xi32>,
+                   %arg10 : vector<8 x f16>,
+                   %arg11 : vector<16 x f16>,
+                   %arg12 : vector<8 x bf16>,
+                   %arg13 : vector<16 x bf16>,
+                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -362,6 +367,76 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
+  // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
   llvm.return %r0 : vector<4 x f32>
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7be6d6ba4d7be..868597fba92a6 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -528,7 +528,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg6 : vector<8 x i16>,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
-                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+                   %arg9 : vector<16xi32>,
+                   %arg10 : vector<8 x f16>,
+                   %arg11 : vector<16 x f16>,
+                   %arg12 : vector<8 x bf16>,
+                   %arg13 : vector<16 x bf16>,
+                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -598,12 +603,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
-
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
   %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
   llvm.return %r0 : vector<4 x f32>
 }
 

>From c6638188980c9cee15b97116ebd908fb866c7a41 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 10 Dec 2025 15:49:58 -0800
Subject: [PATCH 2/2] test nit: reuse existing vector<8xf16> argument in smfmac tests

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/test/Dialect/LLVMIR/rocdl.mlir | 37 ++++++++++++++---------------
 mlir/test/Target/LLVMIR/rocdl.mlir  | 37 ++++++++++++++---------------
 2 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 745fea8e38955..c179b110e32e7 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -289,11 +289,10 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
                    %arg9 : vector<16xi32>,
-                   %arg10 : vector<8 x f16>,
-                   %arg11 : vector<16 x f16>,
-                   %arg12 : vector<8 x bf16>,
-                   %arg13 : vector<16 x bf16>,
-                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
+                   %arg10 : vector<16 x f16>,
+                   %arg11 : vector<8 x bf16>,
+                   %arg12 : vector<16 x bf16>,
+                   %arg13 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -368,72 +367,72 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<8xf16>, vector<16xf16>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<8xf16>, vector<16xf16>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xi32>,
                                  i32, i32, i32) -> vector<4xi32>
 
   // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xi32>,
                                  i32, i32, i32) -> vector<16xi32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 868597fba92a6..a7f16a5054626 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -529,11 +529,10 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
                    %arg9 : vector<16xi32>,
-                   %arg10 : vector<8 x f16>,
-                   %arg11 : vector<16 x f16>,
-                   %arg12 : vector<8 x bf16>,
-                   %arg13 : vector<16 x bf16>,
-                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
+                   %arg10 : vector<16 x f16>,
+                   %arg11 : vector<8 x bf16>,
+                   %arg12 : vector<16 x bf16>,
+                   %arg13 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -609,72 +608,72 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg2, %arg10, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<8xf16>, vector<16xf16>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg2, %arg10, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<8xf16>, vector<16xf16>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg11, %arg12, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg11, %arg12, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
-  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg13, %arg8, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xi32>,
                                  i32, i32, i32) -> vector<4xi32>
 
   // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg13, %arg3, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<4xf32>,
                                  i32, i32, i32) -> vector<4xf32>
 
   // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
-  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg13, %arg9, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xi32>,
                                  i32, i32, i32) -> vector<16xi32>
 
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
-  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg13, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<4xi32>, vector<8xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 



More information about the Mlir-commits mailing list