[Mlir-commits] [mlir] [mlir][amdgpu][rocdl] Add gfx1250 wmma ops (PR #165064)

Jakub Kuderski llvmlistbot at llvm.org
Tue Oct 28 08:35:52 PDT 2025


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/165064

From 4742f281542200695b21081f6cb239f356a42413 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 Oct 2025 18:40:59 -0400
Subject: [PATCH 1/3] [mlir][amdgpu][rocdl] Add gfx1250 wmma ops

Update `amdgpu.wmma` op definition and implement amdgpu to rocdl
conversion for new variants.
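
For illustration, a minimal sketch of one of the new gfx1250 variants and the
lowering this patch implements (assuming the conversion is run with
`--convert-amdgpu-to-rocdl=chipset=gfx1250`, mirroring the wmma-gfx1250.mlir
test added below):

```mlir
// f16 x f16 -> f32 accumulate with the new K=32 shape; expected to lower to
// rocdl.wmma.f32.16x16x32.f16 on gfx1250.
func.func @wmma_f32_16x16x32_f16(%a : vector<16xf16>, %c : vector<8xf32>) -> vector<8xf32> {
  %0 = amdgpu.wmma 16x16x32 %a * %a + %c : vector<16xf16>, vector<16xf16>, vector<8xf32>
  func.return %0 : vector<8xf32>
}
```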
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  27 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 117 ++++++++++++++++--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  10 +-
 .../AMDGPUToROCDL/wmma-gfx1250.mlir           |  89 +++++++++++++
 mlir/test/Dialect/AMDGPU/invalid.mlir         |  60 +++++++--
 mlir/test/Dialect/AMDGPU/ops.mlir             |  35 ++++++
 6 files changed, 306 insertions(+), 32 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 37db096f1ba75..45cb67f0eee4a 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
-                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
-                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
+                             VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
                              VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -992,7 +993,7 @@ def AMDGPU_WMMAOp :
     Arguments<(ins
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
-                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -1005,8 +1006,14 @@ def AMDGPU_WMMAOp :
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
     instructions in the AMDGPU architecture, which perform matrix multiplication.
-    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
-    dimensions.
+
+    On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
+
+    On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
+    all element types, and K=32 for i4 sources.
+
+    On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
+    depending on the element types.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
@@ -1022,7 +1029,13 @@ def AMDGPU_WMMAOp :
 
     Example:
     ```mlir
-      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
+
+      %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+      %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf4E2M1FN>, vector<64xf4E2M1FN>, vector<8xf32>
+
+      %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
     ```
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aaaec83a..90e731c11da62 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1002,8 +1002,13 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
   Type elemDestType = destVectorType.getElementType();
 
   const uint32_t k = wmma.getK();
+  const bool isRDNA3 = chipset.majorVersion == 11;
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
 
   if (k == 16) {
+    if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
     if (elemSourceType.isF16() && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
     if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,34 +1024,124 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
         return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
     }
   }
-  if (chipset.majorVersion < 12)
+  if (isRDNA3)
     return std::nullopt;
 
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
   // gfx12+
   if (k == 16) {
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
 
     return std::nullopt;
   }
   if (k == 32) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (isRDNA4) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return std::nullopt;
+    }
+
+    // gfx1250
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (isRDNA4)
+    return std::nullopt;
+
+  // gfx1250
+  if (k == 4) {
+    if (elemSourceType.isF32() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
     return std::nullopt;
   }
 
+  if (k == 64) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+    }
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 128) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+    }
+
+    return std::nullopt;
+  }
   llvm_unreachable("unhandled WMMA case");
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 585b6dacfa648..b46568a423aa7 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
 
   if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
-               "source element types much match (except for fp8) but have ")
+               "source element types much match (except for fp8/bf8) but have ")
            << sourceAType << " and " << sourceBType;
   }
 
-  if (!sourceAElemType.isInteger(4) && getK() != 16) {
-    return emitOpError("K dimension must be 16 for source element type ")
-           << sourceAElemType;
+  if (isSrcFloat) {
+    if (getClamp())
+      return emitOpError("clamp flag is not supported for float types");
+    if (getUnsignedA() || getUnsignedB())
+      return emitOpError("unsigned flags are not supported for float types");
   }
   return success();
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
new file mode 100644
index 0000000000000..bcbdef040ebe3
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @wmma_k4
+func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
+  // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
+  amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k32
+func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
+                    %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k64
+func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
+                    %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+  amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k128
+func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                     %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 57847641a2d03..594b4921bbe56 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -156,14 +156,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
 
 // -----
 
-func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
-  %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
-  func.return %0 : vector<8xi32>
-}
-
-// -----
-
 func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
   // expected-error at +1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
   %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
@@ -173,14 +165,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
 // -----
 
 func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+  // expected-error at +1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
   %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }
 
 // -----
 
-// Missinng `resetOffset`
+func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source vectors have different lengths}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op clamp flag is not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+// Missing `resetOffset`
 func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
   // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
   %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index a33096750ee23..09134cb4704bb 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -586,6 +586,41 @@ func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) ->
   func.return %0 : vector<8xi32>
 }
 
+// CHECK-LABEL: func @wmma_f32_16x16x4_f32
+func.func @wmma_f32_16x16x4_f32(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: amdgpu.wmma 16x16x4
+  %0 = amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f32_16x16x64_f8
+func.func @wmma_f32_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f32_16x16x64_bf8
+func.func @wmma_f32_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f16_16x16x64_bf8
+func.func @wmma_f16_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+  func.return %0 : vector<8xf16>
+}
+
+// CHECK-LABEL: func @wmma_f16_16x16x64_f8
+func.func @wmma_f16_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+  func.return %0 : vector<8xf16>
+}
+
 // CHECK-LABEL: func @swizzle_bitmode
 func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
   // CHECK: amdgpu.swizzle_bitmode

From e616cf6ea9e386aad6316d267076fd08f194836e Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 25 Oct 2025 07:36:14 -0400
Subject: [PATCH 2/3] Fix typo

---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +-
 mlir/test/Dialect/AMDGPU/invalid.mlir        | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b46568a423aa7..df955fc90b45f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,7 +399,7 @@ LogicalResult WMMAOp::verify() {
 
   if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
-               "source element types much match (except for fp8/bf8) but have ")
+               "source element types must match (except for fp8/bf8) but have ")
            << sourceAType << " and " << sourceBType;
   }
 
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 594b4921bbe56..4c6f62a045405 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -181,7 +181,7 @@ func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16x
 // -----
 
 func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
-  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  // expected-error at +1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}}
   %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
   func.return %0 : vector<8xf32>
 }
@@ -189,7 +189,7 @@ func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xb
 // -----
 
 func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  // expected-error at +1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}}
   %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }

From 7e678485324289a9dad5ab86999000db1e233562 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 28 Oct 2025 11:35:34 -0400
Subject: [PATCH 3/3] Move conversion logic to two helper functions

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 118 +++++++++++-------
 1 file changed, 70 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 90e731c11da62..1eca43d96fe85 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -989,26 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
                                  smfma.getN(), smfma.getK(), 1u, chipset);
 }
 
-/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
-/// if one exists. This includes checking to ensure the intrinsic is supported
-/// on the architecture you are compiling for.
-static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
-                                                  Chipset chipset) {
-  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
-  Type elemSourceType = sourceVectorType.getElementType();
-  Type elemBSourceType = sourceBVectorType.getElementType();
-  Type elemDestType = destVectorType.getElementType();
-
-  const uint32_t k = wmma.getK();
-  const bool isRDNA3 = chipset.majorVersion == 11;
-  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for RDNA3/4 architectures.
+static std::optional<StringRef>
+wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
+                      Type elemDestType, uint32_t k, bool isRDNA3) {
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
 
+  // Handle k == 16 for RDNA3/4.
   if (k == 16) {
-    if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
-      return std::nullopt;
-
+    // Common patterns for RDNA3 and RDNA4.
     if (elemSourceType.isF16() && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
     if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,22 +1010,15 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
       return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
     if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-    if (chipset.majorVersion == 11) {
+
+    // RDNA3 specific patterns.
+    if (isRDNA3) {
       if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
         return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
-  }
-  if (isRDNA3)
-    return std::nullopt;
-
-  using fp8 = Float8E4M3FNType;
-  using bf8 = Float8E5M2Type;
-
-  // gfx12+
-  if (k == 16) {
-    if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
       return std::nullopt;
+    }
 
+    // RDNA4 specific patterns (fp8/bf8).
     if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
         elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1047,20 +1031,38 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
         elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
 
     return std::nullopt;
   }
-  if (k == 32) {
-    if (isRDNA4) {
-      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
-      return std::nullopt;
-    }
 
-    // gfx1250
+  // Handle k == 32 for RDNA4.
+  if (k == 32 && !isRDNA3) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+  }
+
+  llvm_unreachable("Unsupported k value");
+}
+
+/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for the gfx1250 architecture.
+static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
+                                                         Type elemBSourceType,
+                                                         Type elemDestType,
+                                                         uint32_t k) {
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
+  if (k == 4) {
+    if (elemSourceType.isF32() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 32) {
     if (elemSourceType.isF16() && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
     if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1073,16 +1075,6 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return std::nullopt;
   }
 
-  if (isRDNA4)
-    return std::nullopt;
-
-  // gfx1250
-  if (k == 4) {
-    if (elemSourceType.isF32() && elemDestType.isF32())
-      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
-    return std::nullopt;
-  }
-
   if (k == 64) {
     if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
       if (elemDestType.isF32())
@@ -1142,6 +1134,36 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
 
     return std::nullopt;
   }
+
+  llvm_unreachable("Unsupported k value");
+}
+
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+                                                  Chipset chipset) {
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+  const bool isRDNA3 = chipset.majorVersion == 11;
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+
+  // Handle RDNA3 and RDNA4.
+  if (isRDNA3 || isRDNA4)
+    return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
+                                 k, isRDNA3);
+
+  // Handle gfx1250.
+  if (chipset == Chipset{12, 5, 0})
+    return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
+                                    elemDestType, k);
+
   llvm_unreachable("unhandled WMMA case");
 }
 


