[Mlir-commits] [mlir] [mlir][ROCDL] Refactor wmma intrinsics to use attributes not operands where possible (PR #167041)

Muzammiluddin Syed llvmlistbot at llvm.org
Mon Nov 10 10:41:23 PST 2025


https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/167041

>From fc0d7fbd628950bc00f5a169d49e2fac7ba250ca Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Tue, 7 Oct 2025 13:28:29 -0500
Subject: [PATCH 1/4] [mlir][ROCDL] refactor wmma intrinsics to use attributes
 instead of operands where possible

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  | 177 ++++++++++++++----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  87 +++++----
 .../Conversion/AMDGPUToROCDL/wmma-gfx11.mlir  |  16 +-
 .../Conversion/AMDGPUToROCDL/wmma-gfx12.mlir  |  20 +-
 .../AMDGPUToROCDL/wmma-gfx1250.mlir           |  22 +--
 mlir/test/Target/LLVMIR/rocdl.mlir            | 143 +++++++-------
 6 files changed, 286 insertions(+), 179 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 921fdf36a59b0..282f30d68269f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
 
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
-                        list<Trait> traits = []> :
-  ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
-  Arguments<(ins Variadic<LLVM_Type>:$args)> {
-  let assemblyFormat =
-    "$args attr-dict `:` functional-type($args, $res)";
+class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [0], [], 1, 0, 0, 0, [], []>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$opsel)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$clamp)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
+    [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<C>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
 }
 
 // Available from gfx11
-def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
-def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
-def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
-def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
-def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
-def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
 // Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
-def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
 // Available from gfx1250
-def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
-def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
-def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
-def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
-def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
-def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
-def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0756d93..e3425fe703b83 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
   return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
 }
 
-static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
-                              bool value) {
-  Type llvmI1 = rewriter.getI1Type();
-  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
-}
-
 /// Returns the linear index used to access an element in the memref.
 static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                                Location loc, MemRefDescriptor &memRefDescriptor,
@@ -684,12 +679,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
 /// intrinsics having been defined before the AMD backend supported bfloat. We
 /// similarly need to pack 8-bit float types into integers as if they were i8
 /// (which they are for the backend's purposes).
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
-                                 Location loc,
-                                 const TypeConverter *typeConverter,
-                                 bool isUnsigned, Value llvmInput,
-                                 Value mlirInput,
-                                 SmallVector<Value, 4> &operands) {
+static void wmmaPushInputOperand(
+    ConversionPatternRewriter &rewriter, Location loc,
+    const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
+    Value mlirInput, SmallVector<Value, 4> &operands,
+    SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   if (!vectorType) {
@@ -697,10 +691,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     return;
   }
   Type elemType = vectorType.getElementType();
-
-  if (elemType.isBF16())
-    llvmInput = LLVM::BitcastOp::create(
-        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
   if (elemType.getIntOrFloatBitWidth() > 8) {
     operands.push_back(llvmInput);
     return;
@@ -719,8 +709,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
-    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
-    operands.push_back(sign);
+    attrs.push_back(NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
   }
 
   int64_t numBits =
@@ -751,18 +740,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
                                   Value output, int32_t subwordOffset,
-                                  bool clamp, SmallVector<Value, 4> &operands) {
+                                  bool clamp, SmallVector<Value, 4> &operands,
+                                  SmallVector<NamedAttribute, 4> &attrs) {
   Type inputType = output.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   Type elemType = vectorType.getElementType();
-  if (elemType.isBF16())
-    output = LLVM::BitcastOp::create(
-        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
   operands.push_back(output);
   if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
-    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+    attrs.push_back(
+        NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
   } else if (elemType.isInteger(32)) {
-    operands.push_back(createI1Constant(rewriter, loc, clamp));
+    attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
   }
 }
 
@@ -1302,6 +1290,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   LogicalResult
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     Location loc = op.getLoc();
     auto outType =
         typeConverter->convertType<VectorType>(op.getDestD().getType());
@@ -1311,34 +1300,56 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
-    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
-    // need to bitcast bfloats to i16 and then bitcast them back.
+    bool isGFX1250 = chipset == Chipset(12, 5, 0);
+
+    // The WMMA operations represent vectors of bf16s as vectors of i16s
+    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
+    // bitcast them back.
+    auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+    auto bType = dyn_cast<VectorType>(adaptor.getSourceB().getType());
+    auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
+    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
+    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
+    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
+    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
     VectorType rawOutType = outType;
-    if (outType.getElementType().isBF16())
+    if (castOutToI16)
       rawOutType = outType.clone(rewriter.getI16Type());
+    Value a = adaptor.getSourceA();
+    if (castAToI16)
+      a = LLVM::BitcastOp::create(rewriter, loc,
+                                  aType.clone(rewriter.getI16Type()), a);
+    Value b = adaptor.getSourceB();
+    if (castBToI16)
+      b = LLVM::BitcastOp::create(rewriter, loc,
+                                  bType.clone(rewriter.getI16Type()), b);
+    Value destC = adaptor.getDestC();
+    if (castDestCToI16)
+      destC = LLVM::BitcastOp::create(
+          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
 
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
-
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
     if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
       return op.emitOpError("subwordOffset not supported on gfx12+");
 
-    OperationState loweredOp(loc, *maybeIntrinsic);
-    loweredOp.addTypes(rawOutType);
-
     SmallVector<Value, 4> operands;
-    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
-                         adaptor.getSourceA(), op.getSourceA(), operands);
-    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
-                         adaptor.getSourceB(), op.getSourceB(), operands);
-    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
-                          op.getSubwordOffset(), op.getClamp(), operands);
+    SmallVector<NamedAttribute, 4> attrs;
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
+                         op.getSourceA(), operands, attrs, "signA");
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
+                         op.getSourceB(), operands, attrs, "signB");
+    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
+                          op.getSubwordOffset(), op.getClamp(), operands,
+                          attrs);
 
+    OperationState loweredOp(loc, *maybeIntrinsic);
+    loweredOp.addTypes(rawOutType);
     loweredOp.addOperands(operands);
+    loweredOp.addAttributes(attrs);
     Operation *lowered = rewriter.create(loweredOp);
-
     Operation *maybeCastBack = lowered;
     if (rawOutType != outType)
       maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 9fcc1473d4a18..6f3ecf33b48a9 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -13,23 +13,23 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
   amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>
   amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16>
   amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
-  // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
   amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
-  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   return
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 57883473bbf06..d4bff19da418e 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -20,15 +20,15 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
   amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>) -> vector<8xf16>
   amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>) -> vector<4xf16>
   amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
 
-  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>) -> vector<4xi16>
   amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
 
   // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
@@ -51,19 +51,19 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
   // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
   amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
 
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
 
-  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
-  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
 
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i32, i32, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
 
   func.return
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 5e77a3add3184..a1be0323ee5a8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -14,13 +14,13 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
   // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
   amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>)
   amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}} : (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>)
   amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
 
   return
@@ -29,31 +29,31 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
 // CHECK-LABEL: @wmma_k64
 func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
                     %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
-  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, %arg3 {{.*}}
   amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
 
   // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
 
   return
@@ -65,25 +65,25 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
   // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
 
   return
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 3fbd9e0567948..727e7e189e544 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -875,140 +875,139 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
                       %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>,
                       %arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>, %arg13 : vector<64xf32>, 
                       %arg14 : vector<64xi32>, %arg15 : vector<64xf16>, %arg16 : vector<16xbf16>, %arg17 : vector<32xbf16>) -> vector<8xf32> {
-  %zero = llvm.mlir.constant(false) : i1
-  %zero_i16 = llvm.mlir.constant(0 : i16) : i16
-  // ---- Wave32 -----
 
+  // ---- Wave32 -----
+  
   // f16 -> f32
-  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}})
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x float> %{{.*}})
   %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
 
   // bf16 -> f32
-  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}})
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x float> %{{.*}})
   %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
 
   // f16 -> f16 (OPSEL = {0,1})
-  // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}})
-  %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+  // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <16 x half> %{{.*}} i1 false)
+  %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>
 
   // bf16 -> bf16 (OPSEL = {0,1})
-  // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}})
-  %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <16 x i16> %{{.*}} i1 false)
+  %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16>
 
   // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
-  %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r5 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
 
   // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
-  %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r6 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
 
   // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
-  %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
 
   // f32 -> f32
-  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}}, <16 x float> %{{.*}}, i1 {{.*}}, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32>
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}} i1 false, <16 x float> %{{.*}} i16 0, <4 x float> %{{.*}} i1 false, i1 false)
+  %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32>
 
   // f16 -> f32
-  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+  %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32>
 
   // bf16 -> f32
-  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+  %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32>
 
   // f16 -> f16
-  // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16>
+  // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x half> %{{.*}} i1 false, i1 false)
+  %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %arg1, %arg1, %arg9 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf16>) -> vector<32xf16>
 
   // bf16 -> bf16
-  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x bfloat> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg17, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xbf16>, i1, i1) -> vector<32xbf16>
+  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x bfloat> %{{.*}} i1 false, i1 false)
+  %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %arg16, %arg16, %arg17 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xbf16>) -> vector<32xbf16>
 
   // bf16 -> bf16 / f32
-  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xbf16>
+  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+  %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xbf16>
 
   // f8/bf8 -> f16/f32
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
-
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
+  
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
   // iu8 -> i32
-  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <64 x i32> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg14, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
+  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false)
+  %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = false, signB = false} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
 
   // ---- Wave64 -----
 
   // f16 -> f32
-  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <4 x float> %{{.*}})
   %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
 
   // bf16 -> f32
-  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <4 x float> %{{.*}})
   %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
 
   // f16 -> f16 (OPSEL = {0,1})
-  // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}})
-  %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x half> %{{.*}} i1 false)
+  %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16>
 
   // bf16 -> bf16 (OPSEL = {0,1})
-  // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}})
-  %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x i16> %{{.*}} i1 false)
+  %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
 
   // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
-  %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true)
+  %r12 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg5 {signA = false, signB = false, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
 
   // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
-  %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true)
+  %r13 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg5 {signA = false, signB = false, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<4xi32>) -> vector<4xi32>
 
   llvm.return %r0 : vector<8xf32>
 }

>From fef3acee42ac9107b75a2399191a88bbdc24f3dd Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 7 Nov 2025 17:03:09 -0600
Subject: [PATCH 2/4] removing unnecessary changes

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 ++++--
 mlir/test/Target/LLVMIR/rocdl.mlir                  | 1 -
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index e3425fe703b83..59ad3259cfdb6 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -709,7 +709,8 @@ static void wmmaPushInputOperand(
     } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
-    attrs.push_back(NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
+    attrs.push_back(
+        NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
   }
 
   int64_t numBits =
@@ -1290,7 +1291,6 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   LogicalResult
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     Location loc = op.getLoc();
     auto outType =
         typeConverter->convertType<VectorType>(op.getDestD().getType());
@@ -1329,6 +1329,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
           rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
 
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
+
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
@@ -1350,6 +1351,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     loweredOp.addOperands(operands);
     loweredOp.addAttributes(attrs);
     Operation *lowered = rewriter.create(loweredOp);
+
     Operation *maybeCastBack = lowered;
     if (rawOutType != outType)
       maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 727e7e189e544..6603a5e3d125c 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -877,7 +877,6 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
                       %arg14 : vector<64xi32>, %arg15 : vector<64xf16>, %arg16 : vector<16xbf16>, %arg17 : vector<32xbf16>) -> vector<8xf32> {
 
   // ---- Wave32 -----
-  
   // f16 -> f32
   // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x float> %{{.*}})
   %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>

>From 76928f0bebf1e57bdb43bddd5fbf331821dc0d30 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 10 Nov 2025 10:31:53 -0600
Subject: [PATCH 3/4] [mlir][ROCDL] Address review comments

- Use SmallVectorImpl for output parameters per LLVM coding standards
- Fix bType being derived from getSourceA instead of getSourceB (line 1309)
- Use cast<> instead of dyn_cast<> for guaranteed vector types
- Simplify assembly format with functional-type shorthand

Addresses review comments from @kuhar on PR #167041
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td       | 14 +++++++-------
 .../lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 14 +++++++-------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 282f30d68269f..27672e683b8ab 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -590,7 +590,7 @@ class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemon
              LLVM_ScalarOrVectorOf<CD>:$C)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+    $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
   }];
 }
 
@@ -603,7 +603,7 @@ class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<
              DefaultValuedAttr<I1Attr, "0">:$opsel)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+    $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
   }];
 }
 
@@ -618,7 +618,7 @@ class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mne
              DefaultValuedAttr<I1Attr, "0">:$clamp)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+    $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
   }];
 }
 
@@ -635,7 +635,7 @@ class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL
              DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+    $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
   }];
 }
 
@@ -650,7 +650,7 @@ class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<
              DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+    $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
   }];
 }
 
@@ -667,7 +667,7 @@ class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> :
              DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
   let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+    $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
   }];
 }
 
@@ -683,7 +683,7 @@ class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp
              DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
   let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
   let assemblyFormat = [{
-    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+    $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
   }];
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 59ad3259cfdb6..caa9a054ee442 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -682,8 +682,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
 static void wmmaPushInputOperand(
     ConversionPatternRewriter &rewriter, Location loc,
     const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
-    Value mlirInput, SmallVector<Value, 4> &operands,
-    SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
+    Value mlirInput, SmallVectorImpl<Value> &operands,
+    SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   if (!vectorType) {
@@ -741,8 +741,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
                                   Value output, int32_t subwordOffset,
-                                  bool clamp, SmallVector<Value, 4> &operands,
-                                  SmallVector<NamedAttribute, 4> &attrs) {
+                                  bool clamp, SmallVectorImpl<Value> &operands,
+                                  SmallVectorImpl<NamedAttribute> &attrs) {
   Type inputType = output.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   Type elemType = vectorType.getElementType();
@@ -1305,9 +1305,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     // The WMMA operations represent vectors of bf16s as vectors of i16s
     // (except on gfx1250), so we need to bitcast bfloats to i16 and then
     // bitcast them back.
-    auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
-    auto bType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
-    auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
+    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
+    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
+    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
     bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
     bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
     bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;

>From b768b8acee2bcac3ef0de22131737a708182ca80 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 10 Nov 2025 12:12:10 -0600
Subject: [PATCH 4/4] [mlir][ROCDL] Fix Python bindings test for WMMA
 refactoring

Update the Python test to match the new WMMA API where operands are
passed as separate positional arguments and attributes (like opsel) are
passed as keyword arguments with boolean values.

Changes:
- Changed from list-style arguments to positional arguments
- Changed opsel from MLIR Value operand to Python bool attribute
- Removed unnecessary false constant creation

This fixes the CI test failure caused by the Python test still calling
the bindings with the old API signature.
---
 mlir/test/python/dialects/rocdl.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py
index a4a50afa966c7..c73a536e03820 100644
--- a/mlir/test/python/dialects/rocdl.py
+++ b/mlir/test/python/dialects/rocdl.py
@@ -29,13 +29,12 @@ def testSmoke():
     a_frag = arith.constant(v16f32, f32_array)
     b_frag = arith.constant(v16f32, f32_array)
     c_frag = arith.constant(v16f32, f32_array)
-    false = arith.constant(T.bool(), False)
 
-    c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false])
-    # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16
+    c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, a_frag, b_frag, c_frag, opsel=False)
+    # CHECK: %{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
     print(c_frag)
     assert isinstance(c_frag, OpView)
-    # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16
-    c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false])
+    # CHECK: Value(%{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
+    c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, a_frag, b_frag, c_frag, opsel=False)
     print(c_frag)
     assert isinstance(c_frag, Value)



More information about the Mlir-commits mailing list