[Mlir-commits] [mlir] [mlir][ROCDL] adds wmma scaled intrinsics for gfx1250 (PR #165915)

Muzammiluddin Syed llvmlistbot at llvm.org
Fri Oct 31 16:24:41 PDT 2025


https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/165915

>From ba086bdc9dfdbdec43dee8e06af82d5ce80f85f6 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Tue, 7 Oct 2025 13:28:29 -0500
Subject: [PATCH 1/2] [mlir][ROCDL] adds wmma scaled intrinsics for gfx1250

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td |   2 +
 mlir/test/Target/LLVMIR/rocdl.mlir           | 218 +++++++++++++++++++
 2 files changed, 220 insertions(+)
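
For reference, a minimal usage sketch of the new 16x16x128 scaled WMMA ops (operand order and types mirror the test added below; the i32/i16 immediates appear to select matrix formats and scale behaviour, and the concrete constants here are illustrative placeholders, as are the value names %a, %b, %acc, %scaleA, %scaleB):

  %c0   = llvm.mlir.constant(0 : i32) : i32
  %mod0 = llvm.mlir.constant(0 : i16) : i16
  %f    = llvm.mlir.constant(false) : i1
  // fp8 x fp8 inputs with i32 scale operands; the scale16 variant below is
  // identical except that the two scale operands are i64.
  %d = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %c0, %a, %c0, %b, %mod0, %acc, %c0, %c0, %scaleA, %c0, %c0, %scaleB, %f, %f :
         (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>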

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..f14fb42402913 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -627,6 +627,8 @@ def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8
 def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
 def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
 def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4", [1, 3]>;
+def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4", [1, 3]>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 8a848221a50dd..3c985a819ba0b 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1028,6 +1028,224 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
+llvm.func @rocdl.wmma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
+                   %arg1 : vector<4 x f32>, %arg2 : vector<8xi32>,
+                   %arg3 : vector<12xi32>, %arg4 : vector<4xi32>, 
+                   %arg5 : vector<16xi32>, %arg6 : vector<12xi32>,
+                   %arg7 : vector<8xi64>, %arg8 : i64) -> vector<4 x f32> {
+  %cst0 = llvm.mlir.constant(0 : i32) : i32
+  %cst1 = llvm.mlir.constant(1 : i32) : i32
+  %cst2 = llvm.mlir.constant(2 : i32) : i32
+  %cst3 = llvm.mlir.constant(3 : i32) : i32
+  %cst4 = llvm.mlir.constant(4 : i32) : i32
+  %cst0_i16 = llvm.mlir.constant(0 : i16) : i16
+  %zero = llvm.mlir.constant(false) : i1
+  // CHECK-LABEL: rocdl.wmma.scale.f32.16x16x128.f8f6f4
+  
+  
+  // fp8 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r00 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r00_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp8 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r01 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r01_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp8 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r02 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r02_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp8 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r03 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r03_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp8 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r04 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r04_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst0, %arg5, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf8 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r10 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r10_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf8 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r11 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r11_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf8 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r12 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r12_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf8 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r13 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r13_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf8 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r14 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r14_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst1, %arg5, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<16xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp6 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r20 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r20_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp6 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r21 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r21_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp6 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r22 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r22_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp6 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r23 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r23_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp6 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v8i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v8i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r24 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r24_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst2, %arg3, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf6 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r30 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r30_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf6 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v16i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r31 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r31_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf6 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r32 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r32_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf6 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v12i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r33 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r33_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // bf6 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v12i32.v8i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v12i32.v8i32(i32 {{.*}}, <12 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r34 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r34_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst3, %arg3, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<12xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp4 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v8i32.v16i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v8i32.v16i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r40 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r40_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst0, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp4 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v8i32.v16i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v8i32.v16i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r41 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r41_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst1, %arg5, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<16xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp4 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v8i32.v12i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v8i32.v12i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r42 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r42_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst2, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp4 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v8i32.v12i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v8i32.v12i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <12 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r43 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r43_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst3, %arg3, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<12xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  // fp4 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v8i32.v8i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v8i32.v8i32(i32 {{.*}}, <8 x i32> %{{.*}}, i32 {{.*}}, <8 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r44 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  %r44_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (i32, vector<8xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
+
+  llvm.return %r00 : vector<4 x f32>
+}
+
 llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
   //CHECK: call void @llvm.amdgcn.load.to.lds.p7
   rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>

>From 3633a56e4403a51fdb7500e1138ead0b4438d909 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 31 Oct 2025 18:24:06 -0500
Subject: [PATCH 2/2] adding wmma scaled intrinsics with matrix A size 32x128

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td |  2 ++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 19 +++++++++++++------
 2 files changed, 15 insertions(+), 6 deletions(-)
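
A corresponding sketch for the new 32x16x128 f4 ops (operand shapes follow the test hunk below; %a, %b, %acc, %scaleA, %scaleB and the constants reuse the placeholders from the sketch above, and the scale16 variant again differs only in taking i64 scale operands):

  %d = rocdl.wmma.scale.f32.32x16x128.f4 %a, %b, %mod0, %acc, %c0, %c0, %scaleA, %c0, %c0, %scaleB, %f, %f :
         (vector<16xi32>, vector<8xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>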

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index f14fb42402913..e172d68fec3e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -629,6 +629,8 @@ def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8
 def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
 def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4", [1, 3]>;
 def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4", [1, 3]>;
+def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.32x16x128.f4", [0, 1]>;
+def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.32x16x128.f4", [0, 1]>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 3c985a819ba0b..df0fe1531cd62 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1028,11 +1028,12 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
-llvm.func @rocdl.wmma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
-                   %arg1 : vector<4 x f32>, %arg2 : vector<8xi32>,
-                   %arg3 : vector<12xi32>, %arg4 : vector<4xi32>, 
-                   %arg5 : vector<16xi32>, %arg6 : vector<12xi32>,
-                   %arg7 : vector<8xi64>, %arg8 : i64) -> vector<4 x f32> {
+llvm.func @rocdl.wmma.scale.f32.16x16x128.f8f6f4(
+                   %arg0 : i32, %arg1 : vector<4xf32>,
+                   %arg2 : vector<8xi32>, %arg3 : vector<12xi32>,
+                   %arg4 : vector<4xi32>, %arg5 : vector<16xi32>,
+                   %arg6 : vector<12xi32>, %arg7 : vector<8xi64>,
+                   %arg8 : i64, %arg9 : vector<8xf32>) -> vector<4xf32> {
   %cst0 = llvm.mlir.constant(0 : i32) : i32
   %cst1 = llvm.mlir.constant(1 : i32) : i32
   %cst2 = llvm.mlir.constant(2 : i32) : i32
@@ -1042,7 +1043,6 @@ llvm.func @rocdl.wmma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
   %zero = llvm.mlir.constant(false) : i1
   // CHECK-LABEL: rocdl.wmma.scale.f32.16x16x128.f8f6f4
   
-  
   // fp8 * fp8
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 {{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 {{.*}}%{{.*}}, i1 {{.*}}, i1 {{.*}})
@@ -1243,6 +1243,13 @@ llvm.func @rocdl.wmma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
   %r44_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %cst4, %arg2, %cst4, %arg2, %cst0_i16, %arg1, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
                               (i32, vector<8xi32>, i32, vector<8xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>                            
 
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 {{.*}}, <8 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i1 {{.*}}, i1 {{.*}})
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale16.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 {{.*}}, <8 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, i64 %{{.*}}, i1 {{.*}}, i1 {{.*}})
+  %r45 = rocdl.wmma.scale.f32.32x16x128.f4 %arg5, %arg2, %cst0_i16, %arg9, %cst0, %cst0, %arg0, %cst0, %cst0, %arg0, %zero, %zero :
+                              (vector<16xi32>, vector<8xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>
+  %r45_scale16 = rocdl.wmma.scale16.f32.32x16x128.f4 %arg5, %arg2, %cst0_i16, %arg9, %cst0, %cst0, %arg8, %cst0, %cst0, %arg8, %zero, %zero :
+                              (vector<16xi32>, vector<8xi32>, i16, vector<8xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<8xf32>                            
+
   llvm.return %r00 : vector<4 x f32>
 }
 


