[Mlir-commits] [mlir] [mlir][amdgpu][rocdl] Add gfx1250 wmma ops (PR #165064)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 24 19:06:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-amdgpu

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants.

---

Patch is 22.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165064.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+20-7) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+106-11) 
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+6-4) 
- (added) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir (+89) 
- (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+50-10) 
- (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+35) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d74abc22acd5e..99e10b231482f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
-                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
-                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
+                             VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
                              VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -990,7 +991,7 @@ def AMDGPU_WMMAOp :
     Arguments<(ins
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
-                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -1003,8 +1004,14 @@ def AMDGPU_WMMAOp :
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
     instructions in the AMDGPU architecture, which perform matrix multiplication.
-    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
-    dimensions.
+
+    On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
+
+    On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
+    all element types, and K=32 for i4 sources.
+
+    On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
+    depending on the element types.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
@@ -1020,7 +1027,13 @@ def AMDGPU_WMMAOp :
 
     Example:
     ```mlir
-      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
+
+      %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+      %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+      %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
     ```
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aaaec83a..90e731c11da62 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1002,8 +1002,13 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
   Type elemDestType = destVectorType.getElementType();
 
   const uint32_t k = wmma.getK();
+  const bool isRDNA3 = chipset.majorVersion == 11;
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
 
   if (k == 16) {
+    if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
     if (elemSourceType.isF16() && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
     if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,34 +1024,124 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
         return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
     }
   }
-  if (chipset.majorVersion < 12)
+  if (isRDNA3)
     return std::nullopt;
 
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
   // gfx12+
   if (k == 16) {
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
 
     return std::nullopt;
   }
   if (k == 32) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (isRDNA4) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return std::nullopt;
+    }
+
+    // gfx1250
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (isRDNA4)
+    return std::nullopt;
+
+  // gfx1250
+  if (k == 4) {
+    if (elemSourceType.isF32() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
     return std::nullopt;
   }
 
+  if (k == 64) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+    }
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 128) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+    }
+
+    return std::nullopt;
+  }
   llvm_unreachable("unhandled WMMA case");
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4c4965e67676e..7b4e248dcf5e4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
 
   if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
-               "source element types much match (except for fp8) but have ")
+               "source element types much match (except for fp8/bf8) but have ")
            << sourceAType << " and " << sourceBType;
   }
 
-  if (!sourceAElemType.isInteger(4) && getK() != 16) {
-    return emitOpError("K dimension must be 16 for source element type ")
-           << sourceAElemType;
+  if (isSrcFloat) {
+    if (getClamp())
+      return emitOpError("clamp flag is not supported for float types");
+    if (getUnsignedA() || getUnsignedB())
+      return emitOpError("unsigned flags are not supported for float types");
   }
   return success();
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
new file mode 100644
index 0000000000000..bcbdef040ebe3
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @wmma_k4
+func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
+  // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
+  amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k32
+func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
+                    %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k64
+func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
+                    %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+  amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k128
+func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                     %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6a2518a40cc99..14a3fe3af60ec 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -144,14 +144,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
 
 // -----
 
-func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
-  %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
-  func.return %0 : vector<8xi32>
-}
-
-// -----
-
 func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
   // expected-error at +1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
   %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
@@ -161,14 +153,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
 // -----
 
 func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error at +1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+  // expected-error at +1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
   %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }
 
 // -----
 
-// Missinng `resetOffset`
+func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source vectors have different lengths}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error at +1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op clamp flag is not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+// Missing `resetOffset`
 func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
   // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
   %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/165064


More information about the Mlir-commits mailing list