[Mlir-commits] [mlir] [mlir][ROCDL] Adds wmma scaled intrinsics for gfx1250 (PR #165915)

Muzammiluddin Syed llvmlistbot at llvm.org
Wed Nov 19 11:26:29 PST 2025


https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/165915

>From bbc5776603636bd8a590629dc8db935c430252fe Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 19 Nov 2025 12:22:26 -0600
Subject: [PATCH 1/2] [ROCDL] Add scaled WMMA intrinsics for gfx1250

This patch adds support for scaled WMMA intrinsics available on gfx1250+.

Key changes:
- Created two TableGen classes for the intrinsic variants: ROCDL_WMMA_Scale_IntrOp (takes matrix A/B format attributes) and ROCDL_WMMA_Scale_F4_IntrOp (no matrix format attributes).
- Tests in rocdl.mlir updated.

Intrinsics added:
- rocdl.wmma.scale.f32.16x16x128.f8f6f4
- rocdl.wmma.scale16.f32.16x16x128.f8f6f4
- rocdl.wmma.scale.f32.32x16x128.f4
- rocdl.wmma.scale16.f32.32x16x128.f4

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td |  56 +++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 153 ++++++++++++++++++-
 2 files changed, 208 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 19741f10ce8cc..7323ef60e0670 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -701,6 +701,56 @@ class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp
   }];
 }
 
+// Overloaded operands: [1, 3] refers to LLVM intrinsic parameter positions where
+// A is at position 1 and B is at position 3 (after format parameters).
+class ROCDL_WMMA_Scale_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic,
+    [0], [1, 3], [], 1, 0, 0, 0, [0, 2, 4, 6, 7, 9, 10, 12, 13], 
+    ["matrix_a_fmt", "matrix_b_fmt", "modC", "matrix_a_scale", "matrix_a_scale_fmt", 
+     "matrix_b_scale", "matrix_b_scale_fmt", "reuseA", "reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I32Attr, "0">:$matrix_a_fmt,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I32Attr, "0">:$matrix_b_fmt,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale_fmt,
+             ScaleExpTy:$matrix_a_scale_exp,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale_fmt,
+             ScaleExpTy:$matrix_b_scale_exp,
+             DefaultValuedAttr<I1Attr, "0">:$reuseA,
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C `,` $matrix_a_scale_exp `,` $matrix_b_scale_exp attr-dict `:` functional-type(operands, $res)
+  }];
+}
+
+class ROCDL_WMMA_Scale_F4_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic,
+    [0], [0, 1], [], 1, 0, 0, 0, [2, 4, 5, 7, 8, 10, 11], 
+    ["modC", "matrix_a_scale", "matrix_a_scale_fmt", 
+     "matrix_b_scale", "matrix_b_scale_fmt", "reuseA", "reuseB"]>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale_fmt,
+             ScaleExpTy:$matrix_a_scale_exp,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale,
+             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale_fmt,
+             ScaleExpTy:$matrix_b_scale_exp,
+             DefaultValuedAttr<I1Attr, "0">:$reuseA,
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C `,` $matrix_a_scale_exp `,` $matrix_b_scale_exp attr-dict `:` functional-type(operands, $res)
+  }];
+}
+
 // Available from gfx11
 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
 def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
@@ -739,6 +789,12 @@ def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x1
 def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
 def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
 
+// Scaled wmma intrinsics (available from gfx1250)
+def ROCDL_wmma_scale_f32_16x16x128_f8f6f4   : ROCDL_WMMA_Scale_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4", AnyInteger, F32, I32>;
+def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_WMMA_Scale_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4", AnyInteger, F32, I64>;
+def ROCDL_wmma_scale_f32_32x16x128_f4       : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale.f32.32x16x128.f4", AnyInteger, F32, I32>;
+def ROCDL_wmma_scale16_f32_32x16x128_f4     : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale16.f32.32x16x128.f4", AnyInteger, F32, I64>;
+
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index dcf80ad4395de..f0932f268bd56 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -918,6 +918,30 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
   // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
   %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
 
+  // Test signA=true, signB=false for iu8
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r5a = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
+
+  // Test signA=false, signB=true for iu8
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r5b = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = true, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
+
+  // Test signA=true, signB=true, clamp=true for iu8
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true)
+  %r5c = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = true, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
+
+  // Test signA=true, signB=false for iu4
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r6a = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = true, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
+
+  // Test signA=false, signB=true, clamp=true for iu4
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true)
+  %r6b = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = true, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
+
+  // Test signA=true, signB=true for iu4 gfx12
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r6c = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = true, signB = true, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
+
   // f32 -> f32
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}} i1 false, <16 x float> %{{.*}} i16 0, <4 x float> %{{.*}} i1 false, i1 false)
   %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32>
@@ -981,7 +1005,7 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
 
   // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
   %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
-  
+
   // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
   %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
@@ -995,6 +1019,26 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
   // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false)
   %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = false, signB = false} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
 
+  // Test signA=true, signB=true for iu8 gfx1250  
+  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false)
+  %r23a.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
+
+  // Test signA=true, signB=false, reuseA=true, reuseB=true for iu8 gfx1250
+  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 true, i1 true)
+  %r23b.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = false, reuseA = true, reuseB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
+
+  // Test signA=true, signB=true with modC=1 for f32 gfx1250
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 true, <16 x float> %{{.*}} i1 true, <16 x float> %{{.*}} i16 1, <4 x float> %{{.*}} i1 false, i1 false)
+  %r1a.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = true, signB = true, modC = 1 : i16, reuseA = false, reuseB = false} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32>
+
+  // Test with modC=2 and signA=false, signB=true, reuseA=true for f16 gfx1250
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 true, <16 x half> %{{.*}} i16 2, <32 x float> %{{.*}} i1 true, i1 false)
+  %r2a.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = true, modC = 2 : i16, reuseA = true, reuseB = false} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32>
+
+  // Test with modC=3 and signA=true, signB=true, reuseB=true for bf16 gfx1250
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 true, <16 x bfloat> %{{.*}} i1 true, <16 x bfloat> %{{.*}} i16 3, <32 x float> %{{.*}} i1 false, i1 true)
+  %r3a.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = true, signB = true, modC = 3 : i16, reuseA = false, reuseB = true} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32>
+
   // ---- Wave64 -----
 
   // f16 -> f32
@@ -1231,6 +1275,113 @@ llvm.func @rocdl.raw.ptr.buffer.load.lds(%rsrc : !llvm.ptr<8>, %dstLds : !llvm.p
   llvm.return
 }
 
+llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi32>,
+                            %arg3: vector<12xi32>, %arg5: vector<16xi32>,
+                            %arg8: i64, %arg9: vector<8xf32>) -> vector<4xf32> {
+  // CHECK-LABEL: rocdl.wmma.scale
+  
+  // Test with default attributes (all zeros/false)
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 0, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false)
+  %r00 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 0 : i32, modC = 0 : i16,
+     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
+     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+     reuseA = false, reuseB = false} :
+    (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test with different matrix formats (FP8 x BF8)
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false)
+  %r01 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 1 : i32, modC = 0 : i16,
+     matrix_a_scale = 1 : i32, matrix_a_scale_fmt = 1 : i32,
+     matrix_b_scale = 1 : i32, matrix_b_scale_fmt = 1 : i32,
+     reuseA = false, reuseB = false} :
+    (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test with FP8 x FP6 (different vector sizes) and modC = 1 (negate)
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 0, <16 x i32> %{{.*}}, i32 2, <12 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i1 false, i1 false)
+  %r02 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 2 : i32, modC = 1 : i16,
+     matrix_a_scale = 2 : i32, matrix_a_scale_fmt = 2 : i32,
+     matrix_b_scale = 2 : i32, matrix_b_scale_fmt = 2 : i32,
+     reuseA = false, reuseB = false} :
+    (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test with BF8 x BF6 and modC = 2 (abs)
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 1, <16 x i32> %{{.*}}, i32 3, <12 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false)
+  %r03 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 1 : i32, matrix_b_fmt = 3 : i32, modC = 2 : i16,
+     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
+     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+     reuseA = false, reuseB = false} :
+    (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test with FP8 x FP4 and modC = 3 (negate(abs))
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 0, <16 x i32> %{{.*}}, i32 4, <8 x i32> %{{.*}}, i16 3, <4 x float> %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i1 false, i1 false)
+  %r04 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg2, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 4 : i32, modC = 3 : i16,
+     matrix_a_scale = 3 : i32, matrix_a_scale_fmt = 3 : i32,
+     matrix_b_scale = 3 : i32, matrix_b_scale_fmt = 3 : i32,
+     reuseA = false, reuseB = false} :
+    (vector<16xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test with reuseA = true
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 2, <16 x i32> %{{.*}}, i32 2, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 true, i1 false)
+  %r10 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 2 : i32, matrix_b_fmt = 2 : i32, modC = 0 : i16,
+     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
+     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+     reuseA = true, reuseB = false} :
+    (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test with reuseB = true
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 3, <16 x i32> %{{.*}}, i32 3, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 true)
+  %r11 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 3 : i32, matrix_b_fmt = 3 : i32, modC = 0 : i16,
+     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
+     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+     reuseA = false, reuseB = true} :
+    (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test with both reuseA and reuseB = true
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 4, <16 x i32> %{{.*}}, i32 4, <16 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 true, i1 true)
+  %r12 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
+    {matrix_a_fmt = 4 : i32, matrix_b_fmt = 4 : i32, modC = 1 : i16,
+     matrix_a_scale = 1 : i32, matrix_a_scale_fmt = 1 : i32,
+     matrix_b_scale = 1 : i32, matrix_b_scale_fmt = 1 : i32,
+     reuseA = true, reuseB = true} :
+    (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  
+  // Test scale16 variant with i64 scale exponents
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i1 false, i1 false)
+  %r_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg8, %arg8
+    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 1 : i32, modC = 2 : i16,
+     matrix_a_scale = 2 : i32, matrix_a_scale_fmt = 2 : i32,
+     matrix_b_scale = 2 : i32, matrix_b_scale_fmt = 2 : i32,
+     reuseA = false, reuseB = false} :
+    (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+  
+  // Test f4 variant (no matrix format parameters)
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 0, <8 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false)
+  %r_f4 = rocdl.wmma.scale.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg0, %arg0
+    {modC = 0 : i16,
+     matrix_a_scale = 1 : i32, matrix_a_scale_fmt = 1 : i32,
+     matrix_b_scale = 1 : i32, matrix_b_scale_fmt = 1 : i32,
+     reuseA = false, reuseB = false} :
+    (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  
+  // Test f4 scale16 variant with varied attributes
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale16.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 3, <8 x float> %{{.*}}, i32 2, i32 3, i64 %{{.*}}, i32 3, i32 2, i64 %{{.*}}, i1 true, i1 true)
+  %r_f4_scale16 = rocdl.wmma.scale16.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg8, %arg8
+    {modC = 3 : i16,
+     matrix_a_scale = 2 : i32, matrix_a_scale_fmt = 3 : i32,
+     matrix_b_scale = 3 : i32, matrix_b_scale_fmt = 2 : i32,
+     reuseA = true, reuseB = true} :
+    (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+  
+  llvm.return %r00 : vector<4xf32>
+}
+
 llvm.func @rocdl.raw.ptr.buffer.atomic.f32(%rsrc : !llvm.ptr<8>,
                         %offset : i32, %soffset : i32,
                         %vdata1 : f32) {

>From 3526591070e2e34338f1e632dcfae6901c659dfa Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 19 Nov 2025 13:26:12 -0600
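Subject: [PATCH 2/2] renaming attributes

Rename the scaled WMMA op attributes and scale operands to shorter
camelCase names: matrix_a_fmt -> fmtA, matrix_b_fmt -> fmtB, matrix_a_scale -> scaleAType,
matrix_a_scale_fmt -> fmtScaleA, matrix_a_scale_exp -> scaleA (and
likewise for B). Tests in rocdl.mlir are updated to match.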

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 40 ++++++-------
 mlir/test/Target/LLVMIR/rocdl.mlir           | 62 ++++++++++----------
 2 files changed, 51 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 7323ef60e0670..abc6fc6c32d61 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -705,49 +705,49 @@ class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp
 // A is at position 1 and B is at position 3 (after format parameters).
 class ROCDL_WMMA_Scale_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic,
     [0], [1, 3], [], 1, 0, 0, 0, [0, 2, 4, 6, 7, 9, 10, 12, 13], 
-    ["matrix_a_fmt", "matrix_b_fmt", "modC", "matrix_a_scale", "matrix_a_scale_fmt", 
-     "matrix_b_scale", "matrix_b_scale_fmt", "reuseA", "reuseB"]>,
+    ["fmtA", "fmtB", "modC", "scaleAType", "fmtScaleA", 
+     "scaleBType", "fmtScaleB", "reuseA", "reuseB"]>,
   Arguments<(ins 
-             DefaultValuedAttr<I32Attr, "0">:$matrix_a_fmt,
+             DefaultValuedAttr<I32Attr, "0">:$fmtA,
              LLVM_ScalarOrVectorOf<AB>:$A, 
-             DefaultValuedAttr<I32Attr, "0">:$matrix_b_fmt,
+             DefaultValuedAttr<I32Attr, "0">:$fmtB,
              LLVM_ScalarOrVectorOf<AB>:$B,
              DefaultValuedAttr<I16Attr, "0">:$modC, 
              LLVM_ScalarOrVectorOf<CD>:$C,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale_fmt,
-             ScaleExpTy:$matrix_a_scale_exp,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale_fmt,
-             ScaleExpTy:$matrix_b_scale_exp,
+             DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+             DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+             ScaleExpTy:$scaleA,
+             DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+             DefaultValuedAttr<I32Attr, "0">:$fmtScaleB,
+             ScaleExpTy:$scaleB,
              DefaultValuedAttr<I1Attr, "0">:$reuseA,
              DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C `,` $matrix_a_scale_exp `,` $matrix_b_scale_exp attr-dict `:` functional-type(operands, $res)
+    $A `,` $B `,` $C `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res)
   }];
 }
 
 class ROCDL_WMMA_Scale_F4_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic,
     [0], [0, 1], [], 1, 0, 0, 0, [2, 4, 5, 7, 8, 10, 11], 
-    ["modC", "matrix_a_scale", "matrix_a_scale_fmt", 
-     "matrix_b_scale", "matrix_b_scale_fmt", "reuseA", "reuseB"]>,
+    ["modC", "scaleAType", "fmtScaleA", 
+     "scaleBType", "fmtScaleB", "reuseA", "reuseB"]>,
   Arguments<(ins 
              LLVM_ScalarOrVectorOf<AB>:$A, 
              LLVM_ScalarOrVectorOf<AB>:$B,
              DefaultValuedAttr<I16Attr, "0">:$modC, 
              LLVM_ScalarOrVectorOf<CD>:$C,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_a_scale_fmt,
-             ScaleExpTy:$matrix_a_scale_exp,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale,
-             DefaultValuedAttr<I32Attr, "0">:$matrix_b_scale_fmt,
-             ScaleExpTy:$matrix_b_scale_exp,
+             DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+             DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+             ScaleExpTy:$scaleA,
+             DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+             DefaultValuedAttr<I32Attr, "0">:$fmtScaleB,
+             ScaleExpTy:$scaleB,
              DefaultValuedAttr<I1Attr, "0">:$reuseA,
              DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C `,` $matrix_a_scale_exp `,` $matrix_b_scale_exp attr-dict `:` functional-type(operands, $res)
+    $A `,` $B `,` $C `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res)
   }];
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index f0932f268bd56..86b69812787b8 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1283,81 +1283,81 @@ llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi3
   // Test with default attributes (all zeros/false)
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 0, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false)
   %r00 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 0 : i32, modC = 0 : i16,
-     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
-     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+    {fmtA = 0 : i32, fmtB = 0 : i32, modC = 0 : i16,
+     scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+     scaleBType = 0 : i32, fmtScaleB = 0 : i32,
      reuseA = false, reuseB = false} :
     (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test with different matrix formats (FP8 x BF8)
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false)
   %r01 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 1 : i32, modC = 0 : i16,
-     matrix_a_scale = 1 : i32, matrix_a_scale_fmt = 1 : i32,
-     matrix_b_scale = 1 : i32, matrix_b_scale_fmt = 1 : i32,
+    {fmtA = 0 : i32, fmtB = 1 : i32, modC = 0 : i16,
+     scaleAType = 1 : i32, fmtScaleA = 1 : i32,
+     scaleBType = 1 : i32, fmtScaleB = 1 : i32,
      reuseA = false, reuseB = false} :
     (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test with FP8 x FP6 (different vector sizes) and modC = 1 (negate)
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 0, <16 x i32> %{{.*}}, i32 2, <12 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i1 false, i1 false)
   %r02 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 2 : i32, modC = 1 : i16,
-     matrix_a_scale = 2 : i32, matrix_a_scale_fmt = 2 : i32,
-     matrix_b_scale = 2 : i32, matrix_b_scale_fmt = 2 : i32,
+    {fmtA = 0 : i32, fmtB = 2 : i32, modC = 1 : i16,
+     scaleAType = 2 : i32, fmtScaleA = 2 : i32,
+     scaleBType = 2 : i32, fmtScaleB = 2 : i32,
      reuseA = false, reuseB = false} :
     (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test with BF8 x BF6 and modC = 2 (abs)
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 1, <16 x i32> %{{.*}}, i32 3, <12 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false)
   %r03 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 1 : i32, matrix_b_fmt = 3 : i32, modC = 2 : i16,
-     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
-     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+    {fmtA = 1 : i32, fmtB = 3 : i32, modC = 2 : i16,
+     scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+     scaleBType = 0 : i32, fmtScaleB = 0 : i32,
      reuseA = false, reuseB = false} :
     (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test with FP8 x FP4 and modC = 3 (negate(abs))
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 0, <16 x i32> %{{.*}}, i32 4, <8 x i32> %{{.*}}, i16 3, <4 x float> %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i1 false, i1 false)
   %r04 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg2, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 4 : i32, modC = 3 : i16,
-     matrix_a_scale = 3 : i32, matrix_a_scale_fmt = 3 : i32,
-     matrix_b_scale = 3 : i32, matrix_b_scale_fmt = 3 : i32,
+    {fmtA = 0 : i32, fmtB = 4 : i32, modC = 3 : i16,
+     scaleAType = 3 : i32, fmtScaleA = 3 : i32,
+     scaleBType = 3 : i32, fmtScaleB = 3 : i32,
      reuseA = false, reuseB = false} :
     (vector<16xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test with reuseA = true
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 2, <16 x i32> %{{.*}}, i32 2, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 true, i1 false)
   %r10 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 2 : i32, matrix_b_fmt = 2 : i32, modC = 0 : i16,
-     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
-     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+    {fmtA = 2 : i32, fmtB = 2 : i32, modC = 0 : i16,
+     scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+     scaleBType = 0 : i32, fmtScaleB = 0 : i32,
      reuseA = true, reuseB = false} :
     (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test with reuseB = true
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 3, <16 x i32> %{{.*}}, i32 3, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 true)
   %r11 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 3 : i32, matrix_b_fmt = 3 : i32, modC = 0 : i16,
-     matrix_a_scale = 0 : i32, matrix_a_scale_fmt = 0 : i32,
-     matrix_b_scale = 0 : i32, matrix_b_scale_fmt = 0 : i32,
+    {fmtA = 3 : i32, fmtB = 3 : i32, modC = 0 : i16,
+     scaleAType = 0 : i32, fmtScaleA = 0 : i32,
+     scaleBType = 0 : i32, fmtScaleB = 0 : i32,
      reuseA = false, reuseB = true} :
     (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test with both reuseA and reuseB = true
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 4, <16 x i32> %{{.*}}, i32 4, <16 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 true, i1 true)
   %r12 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0
-    {matrix_a_fmt = 4 : i32, matrix_b_fmt = 4 : i32, modC = 1 : i16,
-     matrix_a_scale = 1 : i32, matrix_a_scale_fmt = 1 : i32,
-     matrix_b_scale = 1 : i32, matrix_b_scale_fmt = 1 : i32,
+    {fmtA = 4 : i32, fmtB = 4 : i32, modC = 1 : i16,
+     scaleAType = 1 : i32, fmtScaleA = 1 : i32,
+     scaleBType = 1 : i32, fmtScaleB = 1 : i32,
      reuseA = true, reuseB = true} :
     (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
   
   // Test scale16 variant with i64 scale exponents
   // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i1 false, i1 false)
   %r_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg8, %arg8
-    {matrix_a_fmt = 0 : i32, matrix_b_fmt = 1 : i32, modC = 2 : i16,
-     matrix_a_scale = 2 : i32, matrix_a_scale_fmt = 2 : i32,
-     matrix_b_scale = 2 : i32, matrix_b_scale_fmt = 2 : i32,
+    {fmtA = 0 : i32, fmtB = 1 : i32, modC = 2 : i16,
+     scaleAType = 2 : i32, fmtScaleA = 2 : i32,
+     scaleBType = 2 : i32, fmtScaleB = 2 : i32,
      reuseA = false, reuseB = false} :
     (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
   
@@ -1365,8 +1365,8 @@ llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi3
   // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 0, <8 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false)
   %r_f4 = rocdl.wmma.scale.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg0, %arg0
     {modC = 0 : i16,
-     matrix_a_scale = 1 : i32, matrix_a_scale_fmt = 1 : i32,
-     matrix_b_scale = 1 : i32, matrix_b_scale_fmt = 1 : i32,
+     scaleAType = 1 : i32, fmtScaleA = 1 : i32,
+     scaleBType = 1 : i32, fmtScaleB = 1 : i32,
      reuseA = false, reuseB = false} :
     (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
   
@@ -1374,8 +1374,8 @@ llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi3
   // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale16.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 3, <8 x float> %{{.*}}, i32 2, i32 3, i64 %{{.*}}, i32 3, i32 2, i64 %{{.*}}, i1 true, i1 true)
   %r_f4_scale16 = rocdl.wmma.scale16.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg8, %arg8
     {modC = 3 : i16,
-     matrix_a_scale = 2 : i32, matrix_a_scale_fmt = 3 : i32,
-     matrix_b_scale = 3 : i32, matrix_b_scale_fmt = 2 : i32,
+     scaleAType = 2 : i32, fmtScaleA = 3 : i32,
+     scaleBType = 3 : i32, fmtScaleB = 2 : i32,
      reuseA = true, reuseB = true} :
     (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
   



More information about the Mlir-commits mailing list