[Mlir-commits] [mlir] [mlir][amdgpu][rocdl] Add gfx1250 wmma ops (PR #165064)

Jakub Kuderski llvmlistbot at llvm.org
Sat Oct 25 04:36:35 PDT 2025


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/165064

>From 0f9044840d28d7dbe57f260bc9557251525c9dc7 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 Oct 2025 18:40:59 -0400
Subject: [PATCH 1/2] [mlir][amdgpu][rocdl] Add gfx1250 wmma ops

Update `amdgpu.wmma` op definition and implement amdgpu to rocdl
conversion for new variants.
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  27 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 117 ++++++++++++++++--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  10 +-
 .../AMDGPUToROCDL/wmma-gfx1250.mlir           |  89 +++++++++++++
 mlir/test/Dialect/AMDGPU/invalid.mlir         |  60 +++++++--
 mlir/test/Dialect/AMDGPU/ops.mlir             |  35 ++++++
 6 files changed, 306 insertions(+), 32 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 1437c39c69da2..f9c517bb53bb8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
-                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
-                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
+                             VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
                              VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -992,7 +993,7 @@ def AMDGPU_WMMAOp :
     Arguments<(ins
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
-                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -1005,8 +1006,14 @@ def AMDGPU_WMMAOp :
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
     instructions in the AMDGPU architecture, which perform matrix multiplication.
-    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
-    dimensions.
+
+    On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
+
+    On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
+    all element types, and K=32 for i4 sources.
+
+    On gfx1250, wmma intrinsics have M=N=16 and K of 4 (f32), 32 (f16, bf16),
+    64 (fp8, bf8, i8), or 128 (fp8, bf8), depending on the source element types.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
@@ -1022,7 +1029,13 @@ def AMDGPU_WMMAOp :
 
     Example:
     ```mlir
-      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
+
+      %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+      %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+      %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
     ```
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aaaec83a..90e731c11da62 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1002,8 +1002,13 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
   Type elemDestType = destVectorType.getElementType();
 
   const uint32_t k = wmma.getK();
+  const bool isRDNA3 = chipset.majorVersion == 11;
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
 
   if (k == 16) {
+    if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
     if (elemSourceType.isF16() && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
     if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,34 +1024,124 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
         return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
     }
   }
-  if (chipset.majorVersion < 12)
+  if (isRDNA3)
     return std::nullopt;
 
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
   // gfx12+
   if (k == 16) {
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
 
     return std::nullopt;
   }
   if (k == 32) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (isRDNA4) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return std::nullopt;
+    }
+
+    // gfx1250
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (isRDNA4)
+    return std::nullopt;
+
+  // gfx1250
+  if (k == 4) {
+    if (elemSourceType.isF32() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
     return std::nullopt;
   }
 
+  if (k == 64) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+    }
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 128) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+    }
+
+    return std::nullopt;
+  }
   llvm_unreachable("unhandled WMMA case");
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 585b6dacfa648..b46568a423aa7 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
 
   if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
-               "source element types much match (except for fp8) but have ")
+               "source element types much match (except for fp8/bf8) but have ")
            << sourceAType << " and " << sourceBType;
   }
 
-  if (!sourceAElemType.isInteger(4) && getK() != 16) {
-    return emitOpError("K dimension must be 16 for source element type ")
-           << sourceAElemType;
+  if (isSrcFloat) {
+    if (getClamp())
+      return emitOpError("clamp flag is not supported for float types");
+    if (getUnsignedA() || getUnsignedB())
+      return emitOpError("unsigned flags are not supported for float types");
   }
   return success();
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
new file mode 100644
index 0000000000000..bcbdef040ebe3
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @wmma_k4
+func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
+  // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
+  amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k32
+func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
+                    %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k64
+func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
+                    %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+  amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k128
+func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                     %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index b27dfdf43ca59..0511334b5b51f 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -156,14 +156,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
 
 // -----
 
-func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
-  %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
-  func.return %0 : vector<8xi32>
-}
-
-// -----
-
 func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
   // expected-error at +1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
   %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
@@ -173,14 +165,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
 // -----
 
 func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+  // expected-error at +1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
   %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }
 
 // -----
 
-// Missinng `resetOffset`
+func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source vectors have different lengths}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op clamp flag is not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+// Missing `resetOffset`
 func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
   // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
   %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index d0bf0d73c4875..17d1430d6193c 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -586,6 +586,41 @@ func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) ->
   func.return %0 : vector<8xi32>
 }
 
+// CHECK-LABEL: func @wmma_f32_16x16x4_f32
+func.func @wmma_f32_16x16x4_f32(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: amdgpu.wmma 16x16x4
+  %0 = amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f32_16x16x64_f8
+func.func @wmma_f32_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f32_16x16x64_bf8
+func.func @wmma_f32_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f16_16x16x64_bf8
+func.func @wmma_f16_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+  func.return %0 : vector<8xf16>
+}
+
+// CHECK-LABEL: func @wmma_f16_16x16x64_f8
+func.func @wmma_f16_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+  // CHECK: amdgpu.wmma 16x16x64
+  %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+  func.return %0 : vector<8xf16>
+}
+
 // CHECK-LABEL: func @swizzle_bitmode
 func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
   // CHECK: amdgpu.swizzle_bitmode

>From f58a79acff779598c8c409d15af0c2c64b1096c7 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 25 Oct 2025 07:36:14 -0400
Subject: [PATCH 2/2] Fix typo

---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +-
 mlir/test/Dialect/AMDGPU/invalid.mlir        | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b46568a423aa7..df955fc90b45f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,7 +399,7 @@ LogicalResult WMMAOp::verify() {
 
   if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
-               "source element types much match (except for fp8/bf8) but have ")
+               "source element types must match (except for fp8/bf8) but have ")
            << sourceAType << " and " << sourceBType;
   }
 
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 0511334b5b51f..b87b8d217aecc 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -181,7 +181,7 @@ func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16x
 // -----
 
 func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
-  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  // expected-error at +1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}}
   %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
   func.return %0 : vector<8xf32>
 }
@@ -189,7 +189,7 @@ func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xb
 // -----
 
 func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  // expected-error at +1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}}
   %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }



More information about the Mlir-commits mailing list