[Mlir-commits] [mlir] dc5f274 - [mlir][amdgpu] Add explicit intrinsic shape to wmma (#164920)
llvmlistbot at llvm.org
Fri Oct 24 09:21:37 PDT 2025
Author: Jakub Kuderski
Date: 2025-10-24T12:21:33-04:00
New Revision: dc5f2745604d4c5a003e909574b531662b372355
URL: https://github.com/llvm/llvm-project/commit/dc5f2745604d4c5a003e909574b531662b372355
DIFF: https://github.com/llvm/llvm-project/commit/dc5f2745604d4c5a003e909574b531662b372355.diff
LOG: [mlir][amdgpu] Add explicit intrinsic shape to wmma (#164920)
This is in preparation for adding support for gfx1250 wmma intrinsics
that allow many more shapes.
Instead of guessing the wave32/wave64 mode based on element types and
vector sizes, require the intrinsic shapes to be set explicitly as
attributes.
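For illustration, a minimal before/after sketch of the assembly change (the operand names are hypothetical; the types match the updated tests):

```mlir
// Before: the wave32/wave64 variant was inferred from element types and
// vector sizes.
%0 = amdgpu.wmma %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>

// After: the MxNxK intrinsic shape is stated explicitly.
%1 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
```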
Added:
mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
mlir/include/mlir/IR/CommonAttrConstraints.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
mlir/test/Dialect/AMDGPU/invalid.mlir
mlir/test/Dialect/AMDGPU/ops.mlir
Removed:
mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..d74abc22acd5e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
- [4, 8, 16],
- [F16, BF16,
- I8, SI8, UI8,
- I<4>, SI<4>, UI<4>,
- F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+ VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+ VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
The negateA, negateB, and negateC flags are only supported for double-precision
operations on gfx94x.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.mfma %matA * %matB + %matC
+ { abid = 1 : i32, cbsz = 1 : i32,
+ m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+ blgp = bcast_second_32 : f32, f32, vector<32xf32>
+ ```
}];
let assemblyFormat = [{
$sourceA `*` $sourceB `+` $destC
@@ -982,36 +988,43 @@ def AMDGPU_WMMAOp :
AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Arguments<(ins
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
WMMAInTypes:$sourceA,
WMMAInTypes:$sourceB,
WMMAOutTypes:$destC,
- DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+ DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>, "0">:$subwordOffset,
UnitAttr:$unsignedA,
UnitAttr:$unsignedB,
UnitAttr:$clamp)>,
Results<(outs WMMAOutTypes: $destD)> {
- let summary = "MLIR wrapper for RDNA3 wmma instructions";
+ let summary = "MLIR wrapper for wmma instructions";
let description = [{
- The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
- for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
- Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
- integer inputs.
+ The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+ instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary in the allowed K
+    dimensions.
On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
(or 16xbf16) vector containing only 8 valid values:
- If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
- If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
- On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
- all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as a vector where all
+ the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
`unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
- The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+ The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
in case of overflow.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+ ```
}];
let assemblyFormat = [{
- $sourceA `*` $sourceB `+` $destC
+ custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
attr-dict
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
}];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..dcd9f95a7561f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
//
//===----------------------------------------------------------------------===//
@@ -26,6 +26,29 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
+namespace mlir::amdgpu {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+ IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+ IntegerAttr &m, IntegerAttr &n,
+ IntegerAttr &k) {
+ return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+ IntegerAttr n, IntegerAttr k) {
+ printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+ IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+ printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir::amdgpu
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index e1869c1821b11..b7e168a3e6f86 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -804,6 +804,11 @@ def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;
class IntValidAlignment<Attr attr>: ConfinedAttr<attr, [IntPositivePowerOf2]>;
+class IntIsOneOf<list<int> values> : AttrConstraint<
+ CPred<"::llvm::is_contained({" # !interleave(!foreach(val, values, val), ", ") #
+ "}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
+ "whose value is one of {" # !interleave(!foreach(val, values, val), ", ") # "}">;
+
class ArrayMaxCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
"with at most " # n # " elements">;
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
- auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
- auto elemSourceType = sourceVectorType.getElementType();
- auto elemBSourceType = sourceBVectorType.getElementType();
- auto elemDestType = destVectorType.getElementType();
-
- if (elemSourceType.isF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
- if (elemSourceType.isF16() && elemDestType.isF16())
- return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isBF16())
- return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
- if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+ auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+ Type elemSourceType = sourceVectorType.getElementType();
+ Type elemBSourceType = sourceBVectorType.getElementType();
+ Type elemDestType = destVectorType.getElementType();
+
+ const uint32_t k = wmma.getK();
+
+ if (k == 16) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ if (chipset.majorVersion == 11) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ }
}
- if (chipset.majorVersion >= 12) {
+ if (chipset.majorVersion < 12)
+ return std::nullopt;
+
+ // gfx12+
+ if (k == 16) {
if (isa<Float8E4M3FNType>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
if (isa<Float8E5M2Type>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
- bool isWave64 = destVectorType.getNumElements() == 4;
- // This is the ambiguous case. 8 inputs to the wave64 version means that
- // we want the 16x16x32 version, but for wave32 they mean the short form.
- bool has8Inputs = sourceVectorType.getNumElements() == 8;
- if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
- return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
- }
+
+ return std::nullopt;
}
- return std::nullopt;
+ if (k == 32) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ return std::nullopt;
+ }
+
+ llvm_unreachable("unhandled WMMA case");
}
namespace {
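Note how the explicit `k` attribute resolves the previously ambiguous 4-bit integer case on gfx12: the same `vector<8xi4>` inputs now select the intrinsic by shape rather than by guessing the wave mode. A sketch mirroring the updated gfx12 tests (value names hypothetical):

```mlir
// k = 32 lowers to rocdl.wmma.i32.16x16x32.iu4 (the former wave64 guess).
amdgpu.wmma 16x16x32 %a * %a + %acc4 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
// k = 16 lowers to rocdl.wmma.i32.16x16x16.iu4 for the same input vectors.
amdgpu.wmma 16x16x16 %a * %a + %acc8 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
```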
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..4c4965e67676e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
- Type sourceAType = getSourceA().getType();
- Type sourceBType = getSourceB().getType();
- Type destType = getDestC().getType();
- VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
- VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
- VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+ IntegerAttr &m, IntegerAttr &n,
+ IntegerAttr &k) {
+ SmallVector<int64_t, 3> dimensions;
+ if (parser.parseDimensionList(dimensions, false, false))
+ return failure();
+ if (dimensions.size() != 3)
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected 3 dimensions in MNK dimension list";
- Type sourceAElemType = sourceVectorAType.getElementType();
- Type sourceBElemType = sourceVectorBType.getElementType();
- Type destElemType = destVectorType.getElementType();
+ m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+ n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+ k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+ return success();
+}
- if (sourceVectorAType.getNumElements() !=
- sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ Type sourceAElemType = sourceAType.getElementType();
+ Type sourceBElemType = sourceBType.getElementType();
+ if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
- << sourceVectorAType << " vs. " << sourceVectorBType;
+ << sourceAType << " vs. " << sourceBType;
}
- bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
- bool isSrcFloat =
- isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
- sourceAElemType);
-
- if (isDestFloat && !isSrcFloat) {
- return emitOpError("Expected float sources with float destination");
- }
+ bool isDestFloat = destType.getElementType().isFloat();
+ bool isSrcFloat = sourceAElemType.isFloat();
- if (!isDestFloat && isSrcFloat) {
- return emitOpError("Expected int sources with int destination");
- }
+ if (isDestFloat && !isSrcFloat)
+ return emitOpError("expected float sources with float destination");
+ if (!isDestFloat && isSrcFloat)
+ return emitOpError("expected int sources with int destination");
- if (sourceAElemType != sourceBElemType &&
- !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
- isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+ if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
"source element types much match (except for fp8) but have ")
<< sourceAType << " and " << sourceBType;
}
+
+ if (!sourceAElemType.isInteger(4) && getK() != 16) {
+ return emitOpError("K dimension must be 16 for source element type ")
+ << sourceAElemType;
+ }
return success();
}
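The final check pins all non-`i4` sources to K=16. A hedged sketch of an op the verifier now rejects (operand names hypothetical):

```mlir
// Error: 'amdgpu.wmma' op K dimension must be 16 for source element type 'f16'
%0 = amdgpu.wmma 16x16x32 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf32>
```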
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
%arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
%arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
%arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
// CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
- amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
- amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+ amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
func.return
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
%arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
@@ -9,60 +9,60 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
%arg12 : vector<8xi32>, %arg13 : vector<4xi32>,
%arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) {
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
- amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
- amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
+ amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
- amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
+ amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
- amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
+ amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
- amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
+ amdgpu.wmma 16x16x16 %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
- amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
+ amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
+ amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
+ amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
- amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
+ amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
- amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
+ amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
func.return
}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index a8256b16ed8a1..6a2518a40cc99 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -120,9 +120,49 @@ func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
// -----
-func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}}
- %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+func.func @wmma_f16_i32(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op expected int sources with int destination}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_i16_f32(%arg0 : vector<16xi8>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error@+1 {{'amdgpu.wmma' op expected float sources with float destination}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' expected 3 dimensions in MNK dimension list}}
+ %0 = amdgpu.wmma 16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
+ %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
+ %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
func.return %0 : vector<8xi32>
}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index f9c6899dadfc1..a185eb612c9ac 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -565,13 +565,20 @@ func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
func.return %0 : vector<32xf32>
}
-// CHECK-LABEL: func @wmma
-func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
- // CHECK: amdgpu.wmma
- %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+// CHECK-LABEL: func @wmma_f16_16x16x16_f16
+func.func @wmma_f16_16x16x16_f16(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+ // CHECK: amdgpu.wmma 16x16x16
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
func.return %0 : vector<8xf16>
}
+// CHECK-LABEL: func @wmma_i32_16x16x32_i4
+func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // CHECK: amdgpu.wmma 16x16x32
+ %0 = amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg1 : vector<16xi4>, vector<16xi4>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
// CHECK-LABEL: func @swizzle_bitmode
func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
// CHECK: amdgpu.swizzle_bitmode