[Mlir-commits] [mlir] [mlir][amdgpu] Add explicit intrinsic shape to wmma (PR #164920)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 23 17:56:02 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

This is in preparation for adding support for gfx1250 wmma intrinsics, which include many more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes.

---

Patch is 31.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164920.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+29-16) 
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+24-1) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+40-30) 
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+34-27) 
- (renamed) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+14-13) 
- (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+23-23) 
- (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+43-3) 
- (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+11-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..3a808ff3a01e4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-                             [4, 8, 16],
-                             [F16, BF16,
-                              I8, SI8, UI8,
-                              I<4>, SI<4>, UI<4>,
-                              F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
 
     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.mfma %matA * %matB + %matC
+        { abid = 1 : i32, cbsz = 1 : i32,
+          m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+        blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
@@ -982,6 +988,9 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -990,28 +999,32 @@ def AMDGPU_WMMAOp :
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary in the allowed K
+    dimensions.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as a vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..b6fe61ff1afa2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,6 +26,29 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
 
+namespace mlir {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                  IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+                                         IntegerAttr &m, IntegerAttr &n,
+                                         IntegerAttr &k) {
+  return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+                                  IntegerAttr n, IntegerAttr k) {
+  printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+                                  IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+  printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
-  auto elemSourceType = sourceVectorType.getElementType();
-  auto elemBSourceType = sourceBVectorType.getElementType();
-  auto elemDestType = destVectorType.getElementType();
-
-  if (elemSourceType.isF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isF16() && elemDestType.isF16())
-    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isBF16())
-    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
-    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (chipset.majorVersion == 11) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+
+  if (k == 16) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+    if (chipset.majorVersion == 11) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
   }
-  if (chipset.majorVersion >= 12) {
+  if (chipset.majorVersion < 12)
+    return std::nullopt;
+
+  // gfx12+
+  if (k == 16) {
     if (isa<Float8E4M3FNType>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<Float8E5M2Type>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
-      bool isWave64 = destVectorType.getNumElements() == 4;
-      // This is the ambiguous case. 8 inputs to the wave64 version means that
-      // we want the 16x16x32 version, but for wave32 they mean the short form.
-      bool has8Inputs = sourceVectorType.getNumElements() == 8;
-      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
+
+    return std::nullopt;
   }
-  return std::nullopt;
+  if (k == 32) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    return std::nullopt;
+  }
+
+  llvm_unreachable("unhandled WMMA case");
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..eb40374d61303 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
 
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                        IntegerAttr &n, IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }
 
-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();
 
-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");
 
-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/164920


More information about the Mlir-commits mailing list