[Mlir-commits] [mlir] [mlir][amdgpu] Add explicit intrinsic shape to wmma (PR #164920)
llvmlistbot at llvm.org
Thu Oct 23 17:56:02 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
This is in preparation for adding support for the gfx1250 wmma intrinsics, which allow many more shapes.
Instead of guessing the wave32/wave64 mode from the element types and vector sizes, require the intrinsic shape to be set explicitly as attributes.
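For illustration, a minimal before/after sketch of the assembly change (operand names are placeholders; the types are taken from the updated gfx11 tests in this patch):

```mlir
// Before: the target intrinsic was inferred from element types and vector sizes.
%d0 = amdgpu.wmma %a * %b + %c : vector<16xf16>, vector<16xf16>, vector<8xf32>

// After: the intrinsic shape M x N x K is spelled out explicitly before the operands.
%d1 = amdgpu.wmma 16x16x16 %a * %b + %c : vector<16xf16>, vector<16xf16>, vector<8xf32>
```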
---
Patch is 31.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164920.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+29-16)
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+24-1)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+40-30)
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+34-27)
- (renamed) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+14-13)
- (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+23-23)
- (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+43-3)
- (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+11-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..3a808ff3a01e4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
- [4, 8, 16],
- [F16, BF16,
- I8, SI8, UI8,
- I<4>, SI<4>, UI<4>,
- F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+ VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+ VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
The negateA, negateB, and negateC flags are only supported for double-precision
operations on gfx94x.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.mfma %matA * %matB + %matC
+ { abid = 1 : i32, cbsz = 1 : i32,
+ m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+ blgp = bcast_second_32 : f32, f32, vector<32xf32>
+ ```
}];
let assemblyFormat = [{
$sourceA `*` $sourceB `+` $destC
@@ -982,6 +988,9 @@ def AMDGPU_WMMAOp :
AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Arguments<(ins
+ ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
+ ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
+ ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
WMMAInTypes:$sourceA,
WMMAInTypes:$sourceB,
WMMAOutTypes:$destC,
@@ -990,28 +999,32 @@ def AMDGPU_WMMAOp :
UnitAttr:$unsignedB,
UnitAttr:$clamp)>,
Results<(outs WMMAOutTypes: $destD)> {
- let summary = "MLIR wrapper for RDNA3 wmma instructions";
+ let summary = "MLIR wrapper for wmma instructions";
let description = [{
- The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
- for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
- perform a 16x16 * 16x16 matrix multiplication for different data types.
- Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
- integer inputs.
+ The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+ instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary in the allowed K
+ dimensions.
On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
(or 16xbf16) vector containing only 8 valid values:
- If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
- If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
- On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
- all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as a vector where all
+ the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
`unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
- The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+ The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
in case of overflow.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+ ```
}];
let assemblyFormat = [{
- $sourceA `*` $sourceB `+` $destC
+ custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
attr-dict
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
}];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..b6fe61ff1afa2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
//
//===----------------------------------------------------------------------===//
@@ -26,6 +26,29 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
+namespace mlir {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+ IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+ IntegerAttr &m, IntegerAttr &n,
+ IntegerAttr &k) {
+ return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+ IntegerAttr n, IntegerAttr k) {
+ printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+ IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+ printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
- auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
- auto elemSourceType = sourceVectorType.getElementType();
- auto elemBSourceType = sourceBVectorType.getElementType();
- auto elemDestType = destVectorType.getElementType();
-
- if (elemSourceType.isF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
- if (elemSourceType.isF16() && elemDestType.isF16())
- return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isBF16())
- return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
- if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+ auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+ Type elemSourceType = sourceVectorType.getElementType();
+ Type elemBSourceType = sourceBVectorType.getElementType();
+ Type elemDestType = destVectorType.getElementType();
+
+ const uint32_t k = wmma.getK();
+
+ if (k == 16) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ if (chipset.majorVersion == 11) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ }
}
- if (chipset.majorVersion >= 12) {
+ if (chipset.majorVersion < 12)
+ return std::nullopt;
+
+ // gfx12+
+ if (k == 16) {
if (isa<Float8E4M3FNType>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
if (isa<Float8E5M2Type>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
- bool isWave64 = destVectorType.getNumElements() == 4;
- // This is the ambiguous case. 8 inputs to the wave64 version means that
- // we want the 16x16x32 version, but for wave32 they mean the short form.
- bool has8Inputs = sourceVectorType.getNumElements() == 8;
- if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
- return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
- }
+
+ return std::nullopt;
}
- return std::nullopt;
+ if (k == 32) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ return std::nullopt;
+ }
+
+ llvm_unreachable("unhandled WMMA case");
}
namespace {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..eb40374d61303 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() {
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
- Type sourceAType = getSourceA().getType();
- Type sourceBType = getSourceB().getType();
- Type destType = getDestC().getType();
- VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
- VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
- VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+ IntegerAttr &n, IntegerAttr &k) {
+ SmallVector<int64_t, 3> dimensions;
+ if (parser.parseDimensionList(dimensions, false, false))
+ return failure();
+ if (dimensions.size() != 3)
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected 3 dimensions in MNK dimension list";
- Type sourceAElemType = sourceVectorAType.getElementType();
- Type sourceBElemType = sourceVectorBType.getElementType();
- Type destElemType = destVectorType.getElementType();
+ m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+ n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+ k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+ return success();
+}
- if (sourceVectorAType.getNumElements() !=
- sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ Type sourceAElemType = sourceAType.getElementType();
+ Type sourceBElemType = sourceBType.getElementType();
+ if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
return emitOpError("source vectors have different lengths: ")
- << sourceVectorAType << " vs. " << sourceVectorBType;
+ << sourceAType << " vs. " << sourceBType;
}
- bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
- bool isSrcFloat =
- isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
- sourceAElemType);
-
- if (isDestFloat && !isSrcFloat) {
- return emitOpError("Expected float sources with float destination");
- }
+ bool isDestFloat = destType.getElementType().isFloat();
+ bool isSrcFloat = sourceAElemType.isFloat();
- if (!isDestFloat && isSrcFloat) {
- return emitOpError("Expected int sources with int destination");
- }
+ if (isDestFloat && !isSrcFloat)
+ return emitOpError("expected float sources with float destination");
+ if (!isDestFloat && isSrcFloat)
+ return emitOpError("expected int sources with int destination");
- if (sourceAElemType != sourceBElemType &&
- !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
- isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+ if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
"source element types much match (except for fp8) but have ")
<< sourceAType << " and " << sourceBType;
}
+
+ if (!sourceAElemType.isInteger(4) && getK() != 16) {
+ return emitOpError("K dimension must be 16 for source element type ")
+ << sourceAElemType;
+ }
return success();
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
%arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
%arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
%arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
// CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
- amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
- amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
func.return
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/164920