[Mlir-commits] [mlir] [mlir][amdgpu] Add explicit intrinsic shape to wmma (PR #164920)

Jakub Kuderski llvmlistbot at llvm.org
Fri Oct 24 07:20:28 PDT 2025


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/164920

>From 2d9f5fe167dbb2686bc67328b0ddac76ea579b6e Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 23 Oct 2025 20:12:41 -0400
Subject: [PATCH 1/3] [mlir][amdgpu] Add explicit intrinsic shape to wmma

This is in preparation for adding support for gfx1250 wmma intrinsics
that come in many more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and
vector sizes, require the intrinsic shapes to be set explicitly as
attributes.
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 45 +++++++-----
 .../mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h    | 25 ++++++-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 70 +++++++++++--------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 61 +++++++++-------
 .../{wmma.mlir => wmma-gfx11.mlir}            | 27 +++----
 .../Conversion/AMDGPUToROCDL/wmma-gfx12.mlir  | 46 ++++++------
 mlir/test/Dialect/AMDGPU/invalid.mlir         | 46 +++++++++++-
 mlir/test/Dialect/AMDGPU/ops.mlir             | 15 ++--
 8 files changed, 218 insertions(+), 117 deletions(-)
 rename mlir/test/Conversion/AMDGPUToROCDL/{wmma.mlir => wmma-gfx11.mlir} (59%)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..3a808ff3a01e4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-                             [4, 8, 16],
-                             [F16, BF16,
-                              I8, SI8, UI8,
-                              I<4>, SI<4>, UI<4>,
-                              F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
 
     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.mfma %matA * %matB + %matC
+        { abid = 1 : i32, cbsz = 1 : i32,
+          m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+        blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
@@ -982,6 +988,9 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -990,28 +999,32 @@ def AMDGPU_WMMAOp :
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary in their allowed K
+    dimensions.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as a vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..b6fe61ff1afa2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,6 +26,29 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
 
+namespace mlir {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                  IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+                                         IntegerAttr &m, IntegerAttr &n,
+                                         IntegerAttr &k) {
+  return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+                                  IntegerAttr n, IntegerAttr k) {
+  printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+                                  IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+  printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
-  auto elemSourceType = sourceVectorType.getElementType();
-  auto elemBSourceType = sourceBVectorType.getElementType();
-  auto elemDestType = destVectorType.getElementType();
-
-  if (elemSourceType.isF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isF16() && elemDestType.isF16())
-    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isBF16())
-    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
-    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (chipset.majorVersion == 11) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+
+  if (k == 16) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+    if (chipset.majorVersion == 11) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
   }
-  if (chipset.majorVersion >= 12) {
+  if (chipset.majorVersion < 12)
+    return std::nullopt;
+
+  // gfx12+
+  if (k == 16) {
     if (isa<Float8E4M3FNType>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<Float8E5M2Type>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
-      bool isWave64 = destVectorType.getNumElements() == 4;
-      // This is the ambiguous case. 8 inputs to the wave64 version means that
-      // we want the 16x16x32 version, but for wave32 they mean the short form.
-      bool has8Inputs = sourceVectorType.getNumElements() == 8;
-      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
+
+    return std::nullopt;
   }
-  return std::nullopt;
+  if (k == 32) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    return std::nullopt;
+  }
+
+  llvm_unreachable("unhandled WMMA case");
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..eb40374d61303 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
 
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                        IntegerAttr &n, IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  bool bothFp8 = sourceAElemType.isFloat(8) && sourceBElemType.isFloat(8);
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }

-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();

-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");

-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!bothFp8 && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+  // Only i4 sources support a K dimension other than 16.
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
                          %arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
@@ -9,60 +9,60 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
                          %arg12 : vector<8xi32>, %arg13 : vector<4xi32>,
                          %arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
 
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
 
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
-  amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
+  amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
 
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
-  amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
+  amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
 
   // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
 
   // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
 
   // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
 
   // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
 
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
 
   // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
 
   func.return
 }
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index a8256b16ed8a1..aee4705b26958 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -120,9 +120,49 @@ func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
 
 // -----
 
-func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}}
-  %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+func.func @wmma_f16_i32(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op expected int sources with int destination}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_i8_f32(%arg0 : vector<16xi8>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error@+1 {{'amdgpu.wmma' op expected float sources with float destination}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' expected 3 dimensions in MNK dimension list}}
+  %0 = amdgpu.wmma 16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}}
+  %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}}
+  %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 32 whose value is a power of two}}
+  %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }
 
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index f9c6899dadfc1..a185eb612c9ac 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -565,13 +565,20 @@ func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
   func.return %0 : vector<32xf32>
 }
 
-// CHECK-LABEL: func @wmma
-func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
-  // CHECK: amdgpu.wmma
-  %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+// CHECK-LABEL: func @wmma_f16_16x16x16_f16
+func.func @wmma_f16_16x16x16_f16(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+  // CHECK: amdgpu.wmma 16x16x16
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
   func.return %0 : vector<8xf16>
 }
 
+// CHECK-LABEL: func @wmma_i32_16x16x32_i4
+func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // CHECK: amdgpu.wmma 16x16x32
+  %0 = amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg1 : vector<16xi4>, vector<16xi4>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
 // CHECK-LABEL: func @swizzle_bitmode
 func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
   // CHECK: amdgpu.swizzle_bitmode

>From f8d697a1b426d5ab8a279d4457485607a1b2c802 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 23 Oct 2025 21:00:26 -0400
Subject: [PATCH 2/3] Fix namespace

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h | 4 ++--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp        | 5 +++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index b6fe61ff1afa2..dcd9f95a7561f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -26,7 +26,7 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
 
-namespace mlir {
+namespace mlir::amdgpu {
 /// Parser for the `custom<MNKDimensionList>` custom assembly format used by
 /// WMMAOp.
 ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
@@ -47,7 +47,7 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
                                   IntegerAttr m, IntegerAttr n, IntegerAttr k) {
   printMNKDimensionList(printer, m, n, k);
 }
-} // namespace mlir
+} // namespace mlir::amdgpu
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index eb40374d61303..4c4965e67676e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -361,8 +361,9 @@ LogicalResult ScaledExtPacked816Op::verify() {
 // WMMAOp
 //===----------------------------------------------------------------------===//
 
-ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
-                                        IntegerAttr &n, IntegerAttr &k) {
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+                                                IntegerAttr &m, IntegerAttr &n,
+                                                IntegerAttr &k) {
   SmallVector<int64_t, 3> dimensions;
   if (parser.parseDimensionList(dimensions, false, false))
     return failure();

>From 050028521e00e36297966552e3d8278b6b4a4749 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 Oct 2025 10:20:16 -0400
Subject: [PATCH 3/3] Simplify confined attrs

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 8 ++++----
 mlir/include/mlir/IR/CommonAttrConstraints.td | 5 +++++
 mlir/test/Dialect/AMDGPU/invalid.mlir         | 6 +++---
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3a808ff3a01e4..d74abc22acd5e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -988,13 +988,13 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
-                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
-                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
-                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
-                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>, "0">:$subwordOffset,
                    UnitAttr:$unsignedA,
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index e1869c1821b11..b7e168a3e6f86 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -804,6 +804,11 @@ def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;
 
 class IntValidAlignment<Attr attr>: ConfinedAttr<attr, [IntPositivePowerOf2]>;
 
+class IntIsOneOf<list<int> values> : AttrConstraint<
+    CPred<"::llvm::is_contained({" # !interleave(!foreach(val, values, val), ", ") #
+                                "}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
+    "whose value is one of {" # !interleave(!foreach(val, values, val), ", ") # "}">;
+
 class ArrayMaxCount<int n> : AttrConstraint<
     CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
     "with at most " # n # " elements">;
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index aee4705b26958..6a2518a40cc99 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -145,7 +145,7 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
 // -----
 
 func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}}
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
   %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }
@@ -153,7 +153,7 @@ func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
 // -----
 
 func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}}
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
   %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }
@@ -161,7 +161,7 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
 // -----
 
 func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 32 whose value is a power of two}}
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
   %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }



More information about the Mlir-commits mailing list