[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Justin Rosner
llvmlistbot at llvm.org
Thu Nov 27 11:44:46 PST 2025
https://github.com/justinrosner updated https://github.com/llvm/llvm-project/pull/169854
From 32b0c2c160607874949a87ecf38fa68bd118bd86 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Thu, 27 Nov 2025 19:25:56 +0000
Subject: [PATCH 1/3] Add scaled WMMA to AMDGPU
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 60 ++++++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 128 ++++++++++++++++--
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 71 ++++++++++
3 files changed, 244 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e07c72b839e7c..a2201d3127370 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -951,6 +951,13 @@ def MFMAOutTypes : AnyTypeOf<[F64,
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
+
+// scaled_wmma
+def ScaledWMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+ VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+ VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32]>]>;
+
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
@@ -1218,6 +1225,59 @@ def AMDGPU_ScaledMFMAOp :
let hasCanonicalizer = 1;
}
+def AMDGPU_ScaledWMMAOp :
+ AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+ ScaledWMMAInTypes:$sourceA,
+ ScaledWMMAInTypes:$sourceB,
+ ScaledWMMAOutTypes:$destC,
+ AnyTypeOf<[I32, I64]>:$scaleA,
+ AnyTypeOf<[I32, I64]>:$scaleB,
+ DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+ DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+ DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+ DefaultValuedAttr<I32Attr, "0">:$fmtScaleB
+ )>,
+ Results<(outs ScaledWMMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for RDNA scaled wmma instructions";
+ let description = [{
+ The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
+ `wmma` instructions in the RDNA architecture. These instructions perform
+ matrix multiplication with per-block scaling of inputs, supporting fp4, fp6,
+ and fp8 data formats.
+
+    The scaled `wmma` instructions support two tile sizes:
+ - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
+ - 32x16x128 with f4 format only (output: vector<8xf32>)
+
+ The `scaleA` and `scaleB` parameters are scale exponents that can be either
+ i32 (for wmma.scale) or i64 (for wmma.scale16) to support per-block scaling.
+
+ Optional modifiers:
+ - `scaleAType`, `scaleBType`: Type of scale parameter
+ - `fmtScaleA`, `fmtScaleB`: Format of scale parameter
+
+ Example:
+ ```mlir
+ %0 = amdgpu.scaled_wmma (%sa * %matA) * (%sb * %matB) + %matC
+      { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+ %1 = amdgpu.scaled_wmma (%sc * %matD) * (%sd * %matE) + %matF
+      { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
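+
+    // A sketch of the scale16 form: using i64 scale exponents (%se and %sf,
+    // hypothetical values) selects the `wmma.scale16` intrinsics instead of
+    // `wmma.scale`.
+    %2 = amdgpu.scaled_wmma (%se * %matA) * (%sf * %matB) + %matC
+      { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>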
+ ```
+ }];
+ let assemblyFormat = [{
+ `(` $scaleA `*` $sourceA `)` `*` `(` $scaleB `*` $sourceB `)` `+` $destC
+ attr-dict
+ `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_MakeDmaBaseOp :
AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b9a5e7d7f6eac..02f0e14791bf2 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -612,8 +612,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
} // namespace
-/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
-/// and LLVM AMDGPU intrinsics convention.
+/// Packs a small float vector operand (fp4/fp6/fp8/bf16) into the format
+/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -627,9 +627,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
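+/// For example, a `vector<64xf8E4M3FN>` operand (already type-converted to
+/// `vector<64xi8>`) is bitcast to `vector<16xi32>` before being handed to the
+/// intrinsic.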
-static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
- Location loc, Value input,
- bool allowBf16 = true) {
+static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input,
+ bool allowBf16 = true) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
@@ -918,7 +918,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return std::nullopt;
}
-static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
.Case([](Float8E4M3FNType) { return 0u; })
.Case([](Float8E5M2Type) { return 1u; })
@@ -947,8 +947,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
if (!isa<Float32Type>(destType))
return std::nullopt;
- std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
- std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+ std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
+ std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
if (!aTypeCode || !bTypeCode)
return std::nullopt;
@@ -1212,11 +1212,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
- loweredOp.addOperands({convertMFMAVectorOperand(
- rewriter, loc, adaptor.getSourceA(), allowBf16),
- convertMFMAVectorOperand(
- rewriter, loc, adaptor.getSourceB(), allowBf16),
- adaptor.getDestC()});
+ loweredOp.addOperands(
+ {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
+ allowBf16),
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
+ allowBf16),
+ adaptor.getDestC()});
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1261,8 +1262,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
- {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
- convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+ {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC()});
Value scalesIdxA =
createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
@@ -1363,6 +1364,103 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+ ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto outType =
+ typeConverter->convertType<VectorType>(op.getDestD().getType());
+ if (!outType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ if (chipset < Chipset(12, 5, 0))
+ return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+ int64_t m = op.getM();
+ int64_t n = op.getN();
+ int64_t k = op.getK();
+
+ Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+ Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+ std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+ std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+ if (!aFmtCode || !bFmtCode)
+ return op.emitOpError("unsupported element types for scaled_wmma");
+
+ // Determine which intrinsic to use based on dimensions and scale type
+ StringRef intrinsicName;
+ bool isScale16 = adaptor.getScaleA().getType().isInteger(64);
+ bool is32x16 = (m == 32 && n == 16 && k == 128);
+
+ if (m == 16 && n == 16 && k == 128) {
+ intrinsicName = isScale16
+ ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+ : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+ } else if (is32x16) {
+ intrinsicName = isScale16
+ ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+ : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+ } else {
+ return op.emitOpError("unsupported scaled_wmma dimensions: ")
+ << m << "x" << n << "x" << k;
+ }
+
+ SmallVector<NamedAttribute, 8> attrs;
+
+ // The f4 variant does not have fmtA and fmtB attributes
+ if (!is32x16) {
+ attrs.push_back(rewriter.getNamedAttr("fmtA",
+ rewriter.getI32IntegerAttr(*aFmtCode)));
+ attrs.push_back(rewriter.getNamedAttr("fmtB",
+ rewriter.getI32IntegerAttr(*bFmtCode)));
+ }
+
+ // Add modifier attributes - modC and reuse flags default to 0/false
+ attrs.push_back(rewriter.getNamedAttr("reuseA",
+ rewriter.getBoolAttr(false)));
+ attrs.push_back(rewriter.getNamedAttr("reuseB",
+ rewriter.getBoolAttr(false)));
+ attrs.push_back(rewriter.getNamedAttr("modC",
+ rewriter.getI16IntegerAttr(0)));
+
+ // Scale type/format parameters from the operation
+ attrs.push_back(rewriter.getNamedAttr("scaleAType",
+ rewriter.getI32IntegerAttr(op.getScaleAType())));
+ attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
+ rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+ attrs.push_back(rewriter.getNamedAttr("scaleBType",
+ rewriter.getI32IntegerAttr(op.getScaleBType())));
+ attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
+ rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+
+ // Convert typed float vectors to packed i32 format if needed
+ Value sourceA =
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
+ Value sourceB =
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
+
+ // Create the intrinsic call
+ OperationState loweredOp(loc, intrinsicName);
+ loweredOp.addTypes(outType);
+ loweredOp.addOperands({sourceA, sourceB, adaptor.getDestC(),
+ adaptor.getScaleA(), adaptor.getScaleB()});
+ loweredOp.addAttributes(attrs);
+
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered->getResults());
+
+ return success();
+ }
+};
+
struct TransposeLoadOpLowering
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cdc10c60a42ae..87bd1903290ae 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -442,6 +442,77 @@ LogicalResult WMMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScaledWMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ // Validate output type is F32
+ if (!destType.getElementType().isF32())
+ return emitOpError("destination must have f32 element type");
+
+ // Validate source element types are small floats (fp4/fp6/fp8)
+ Type aElemType = sourceAType.getElementType();
+ Type bElemType = sourceBType.getElementType();
+
+ bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
+ aElemType.isFloat(8);
+ bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
+ bElemType.isFloat(8);
+
+ if (!aIsSmallFloat || !bIsSmallFloat)
+ return emitOpError("source operands must have small float element types "
+ "(fp4/fp6/fp8)");
+
+ // Validate scale types match (both i32 or both i64)
+ Type scaleAType = getScaleA().getType();
+ Type scaleBType = getScaleB().getType();
+ if (scaleAType != scaleBType)
+ return emitOpError("scaleA and scaleB must have the same type");
+
+ // Validate vector lengths based on dimensions
+ int64_t m = getM();
+ int64_t aLen = sourceAType.getNumElements();
+ int64_t bLen = sourceBType.getNumElements();
+ int64_t expectedOutLen = (m == 16) ? 4 : 8;
+
+ if (destType.getNumElements() != expectedOutLen)
+ return emitOpError("expected output vector of length " +
+ Twine(expectedOutLen) + " but got " +
+ Twine(destType.getNumElements()));
+
+    // For 16x16x128: both A and B must be 64 elements
+ // For 16×16×128: both A and B must be 64 elements
+ if (aLen != 64)
+ return emitOpError(
+ "for 16x16x128, sourceA must have 64 elements but got " +
+ Twine(aLen));
+ if (bLen != 64)
+ return emitOpError(
+ "for 16x16x128, sourceB must have 64 elements but got " +
+ Twine(bLen));
+ } else { // m == 32
+    // For 32x16x128: only fp4 is supported, A is 128, B is 64
+ if (!aElemType.isFloat(4))
+ return emitOpError("32x16x128 only supports fp4 element types");
+
+ if (aLen != 128)
+ return emitOpError(
+ "for 32x16x128, sourceA must have 128 elements but got " +
+ Twine(aLen));
+ if (bLen != 64)
+ return emitOpError(
+ "for 32x16x128, sourceB must have 64 elements but got " +
+ Twine(bLen));
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
From 8750acce924f204b8a3381166ec2fe2e92dc0d76 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Thu, 27 Nov 2025 19:36:08 +0000
Subject: [PATCH 2/3] Add LIT tests
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 +-
.../AMDGPUToROCDL/wmma-gfx1250.mlir | 73 +++++++++++++++++++
mlir/test/Dialect/AMDGPU/ops.mlir | 26 +++++++
3 files changed, 101 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 02f0e14791bf2..f4034f44d06b8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2427,7 +2427,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
+ WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
+ ScaledExtPacked816OpLowering,
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering, TransposeLoadOpLowering,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 37259f6ed06eb..d187e62484059 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -89,6 +89,79 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
return
}
+// CHECK-LABEL: @wmma_scale_16x16x128_fp8
+func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+ %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 1 : i32, fmtB = 1 : i32} : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_fp6
+func.func @wmma_scale_16x16x128_fp6(%arg0 : vector<64xf6E2M3FN>, %arg1 : vector<64xf6E3M2FN>,
+ %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E3M2FN>, i32, vector<64xf6E3M2FN>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_mixed
+func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
+ %arg2 : vector<64xf4E2M1FN>, %arg3 : vector<4xf32>,
+ %arg4 : i32, %arg5 : i32) {
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtB = 2 : i32} : (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg4 * %arg0) * (%arg5 * %arg1) + %arg3
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtA = 2 : i32, fmtB = 4 : i32} : (vector<12xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %1 = amdgpu.scaled_wmma (%arg4 * %arg1) * (%arg5 * %arg2) + %arg3
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf4E2M1FN>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_16x16x128_fp8
+func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+ %arg2 : vector<4xf32>, %arg3 : i64, %arg4 : i64) {
+ // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale_32x16x128_fp4
+func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+ %arg2 : vector<8xf32>, %arg3 : i32, %arg4 : i32) {
+ // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2
+ { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_32x16x128_fp4
+func.func @wmma_scale16_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+ %arg2 : vector<8xf32>, %arg3 : i64, %arg4 : i64) {
+ // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2
+ { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<128xf4E2M1FN>, i64, vector<64xf4E2M1FN>, vector<8xf32>
+
+ func.return
+}
+
// -----
func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 653f9f64d24f4..f1492b2d4d14e 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -697,3 +697,29 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32,
func.return
}
+// CHECK-LABEL: func @wmma_scale
+func.func @wmma_scale(%fp8_src: vector<64xf8E4M3FN>, %bf8_src: vector<64xf8E5M2>,
+ %fp6_src: vector<64xf6E2M3FN>, %fp4_src_a: vector<128xf4E2M1FN>,
+ %fp4_src_b: vector<64xf4E2M1FN>,
+ %dst0: vector<4xf32>, %dst1: vector<8xf32>,
+ %scale32: i32, %scale64: i64) {
+ // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%scale32 * %fp8_src) * (%scale32 * %fp8_src) + %dst0
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+ // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+ %1 = amdgpu.scaled_wmma (%scale32 * %bf8_src) * (%scale32 * %bf8_src) + %dst0
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+ // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+ %2 = amdgpu.scaled_wmma (%scale32 * %fp6_src) * (%scale32 * %fp6_src) + %dst0
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+ // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf4E2M1FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+ %3 = amdgpu.scaled_wmma (%scale32 * %fp4_src_b) * (%scale32 * %fp6_src) + %dst0
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf4E2M1FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+ // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+ %4 = amdgpu.scaled_wmma (%scale64 * %fp8_src) * (%scale64 * %fp8_src) + %dst0
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+ // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 32 : i32, n = 16 : i32} : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+ %5 = amdgpu.scaled_wmma (%scale32 * %fp4_src_a) * (%scale32 * %fp4_src_b) + %dst1
+ { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+ func.return
+}
From e8e4aa1c7f1dc9d0421a4a907e8dc9b156387758 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Thu, 27 Nov 2025 19:44:22 +0000
Subject: [PATCH 3/3] Clang-format
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 38 +++++-----
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 69 +++++++++----------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 12 ++--
3 files changed, 57 insertions(+), 62 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index a2201d3127370..9d65130154010 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -953,9 +953,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// scaled_wmma
-def ScaledWMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
- VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
- VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+def ScaledWMMAInTypes
+ : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+ VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+ VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32]>]>;
// wmma
@@ -1225,24 +1226,19 @@ def AMDGPU_ScaledMFMAOp :
let hasCanonicalizer = 1;
}
-def AMDGPU_ScaledWMMAOp :
- AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>,
- Pure]>,
- Arguments<(ins
- ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
- ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
- ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
- ScaledWMMAInTypes:$sourceA,
- ScaledWMMAInTypes:$sourceB,
- ScaledWMMAOutTypes:$destC,
- AnyTypeOf<[I32, I64]>:$scaleA,
- AnyTypeOf<[I32, I64]>:$scaleB,
- DefaultValuedAttr<I32Attr, "0">:$scaleAType,
- DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
- DefaultValuedAttr<I32Attr, "0">:$scaleBType,
- DefaultValuedAttr<I32Attr, "0">:$fmtScaleB
- )>,
- Results<(outs ScaledWMMAOutTypes: $destD)> {
+def AMDGPU_ScaledWMMAOp
+ : AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>,
+ Arguments<(ins ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+ ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
+ ScaledWMMAOutTypes:$destC, AnyTypeOf<[I32, I64]>:$scaleA,
+ AnyTypeOf<[I32, I64]>:$scaleB,
+ DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+ DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+ DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+ DefaultValuedAttr<I32Attr, "0">:$fmtScaleB)>,
+ Results<(outs ScaledWMMAOutTypes:$destD)> {
let summary = "MLIR wrapper for RDNA scaled wmma instructions";
let description = [{
The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f4034f44d06b8..731e33c82b75f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1212,12 +1212,11 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
- loweredOp.addOperands(
- {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
- allowBf16),
- packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
- allowBf16),
- adaptor.getDestC()});
+ loweredOp.addOperands({packSmallFloatVectorOperand(
+ rewriter, loc, adaptor.getSourceA(), allowBf16),
+ packSmallFloatVectorOperand(
+ rewriter, loc, adaptor.getSourceB(), allowBf16),
+ adaptor.getDestC()});
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1401,13 +1400,14 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
bool is32x16 = (m == 32 && n == 16 && k == 128);
if (m == 16 && n == 16 && k == 128) {
- intrinsicName = isScale16
- ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
- : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+ intrinsicName =
+ isScale16
+ ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+ : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
} else if (is32x16) {
- intrinsicName = isScale16
- ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
- : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+ intrinsicName =
+ isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+ : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
} else {
return op.emitOpError("unsupported scaled_wmma dimensions: ")
<< m << "x" << n << "x" << k;
@@ -1417,29 +1417,29 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
// The f4 variant does not have fmtA and fmtB attributes
if (!is32x16) {
- attrs.push_back(rewriter.getNamedAttr("fmtA",
- rewriter.getI32IntegerAttr(*aFmtCode)));
- attrs.push_back(rewriter.getNamedAttr("fmtB",
- rewriter.getI32IntegerAttr(*bFmtCode)));
+ attrs.push_back(
+ rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
+ attrs.push_back(
+ rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
}
// Add modifier attributes - modC and reuse flags default to 0/false
- attrs.push_back(rewriter.getNamedAttr("reuseA",
- rewriter.getBoolAttr(false)));
- attrs.push_back(rewriter.getNamedAttr("reuseB",
- rewriter.getBoolAttr(false)));
- attrs.push_back(rewriter.getNamedAttr("modC",
- rewriter.getI16IntegerAttr(0)));
+ attrs.push_back(
+ rewriter.getNamedAttr("reuseA", rewriter.getBoolAttr(false)));
+ attrs.push_back(
+ rewriter.getNamedAttr("reuseB", rewriter.getBoolAttr(false)));
+ attrs.push_back(
+ rewriter.getNamedAttr("modC", rewriter.getI16IntegerAttr(0)));
// Scale type/format parameters from the operation
- attrs.push_back(rewriter.getNamedAttr("scaleAType",
- rewriter.getI32IntegerAttr(op.getScaleAType())));
- attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
- rewriter.getI32IntegerAttr(op.getFmtScaleA())));
- attrs.push_back(rewriter.getNamedAttr("scaleBType",
- rewriter.getI32IntegerAttr(op.getScaleBType())));
- attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
- rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "scaleAType", rewriter.getI32IntegerAttr(op.getScaleAType())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "fmtScaleA", rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "scaleBType", rewriter.getI32IntegerAttr(op.getScaleBType())));
+ attrs.push_back(rewriter.getNamedAttr(
+ "fmtScaleB", rewriter.getI32IntegerAttr(op.getFmtScaleB())));
// Convert typed float vectors to packed i32 format if needed
Value sourceA =
@@ -2428,10 +2428,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
- ScaledExtPacked816OpLowering,
- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering, TransposeLoadOpLowering,
- AMDGPUPermlaneLowering>(converter, chipset);
+ ScaledExtPacked816OpLowering, ScaledExtPackedOpLowering,
+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 87bd1903290ae..fceded5c2c01f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -450,7 +450,7 @@ LogicalResult ScaledWMMAOp::verify() {
auto sourceAType = cast<VectorType>(getSourceA().getType());
auto sourceBType = cast<VectorType>(getSourceB().getType());
auto destType = cast<VectorType>(getDestC().getType());
-
+
// Validate output type is F32
if (!destType.getElementType().isF32())
return emitOpError("destination must have f32 element type");
@@ -459,10 +459,10 @@ LogicalResult ScaledWMMAOp::verify() {
Type aElemType = sourceAType.getElementType();
Type bElemType = sourceBType.getElementType();
- bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
- aElemType.isFloat(8);
- bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
- bElemType.isFloat(8);
+ bool aIsSmallFloat =
+ aElemType.isFloat(4) || aElemType.isFloat(6) || aElemType.isFloat(8);
+ bool bIsSmallFloat =
+ bElemType.isFloat(4) || bElemType.isFloat(6) || bElemType.isFloat(8);
if (!aIsSmallFloat || !bIsSmallFloat)
return emitOpError("source operands must have small float element types "
@@ -479,7 +479,7 @@ LogicalResult ScaledWMMAOp::verify() {
int64_t aLen = sourceAType.getNumElements();
int64_t bLen = sourceBType.getNumElements();
int64_t expectedOutLen = (m == 16) ? 4 : 8;
-
+
if (destType.getNumElements() != expectedOutLen)
return emitOpError("expected output vector of length " +
Twine(expectedOutLen) + " but got " +