[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Justin Rosner
llvmlistbot at llvm.org
Fri Dec 5 07:45:58 PST 2025
https://github.com/justinrosner updated https://github.com/llvm/llvm-project/pull/169854
>From 0291d872fc1e595eb64a7ac167b78e31db0b022f Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Thu, 27 Nov 2025 19:25:56 +0000
Subject: [PATCH 1/2] Add scaled WMMA to AMDGPU
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 75 +++++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 193 +++++++++++++++---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 92 +++++++++
.../AMDGPUToROCDL/wmma-gfx1250.mlir | 112 ++++++++++
mlir/test/Dialect/AMDGPU/ops.mlir | 21 ++
5 files changed, 467 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 6ac84c646e3ae..63ce55dda8e98 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -959,6 +959,15 @@ def MFMAOutTypes : AnyTypeOf<[F64,
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
+
+// scaled_wmma source operand types. f8 and f6 sources always hold 64
+// elements; f4 sources hold 64 elements (both operands of 16x16x128 and
+// operand B of 32x16x128) or 128 elements (operand A of 32x16x128).
+def ScaledWMMAInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+    VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+    VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+// scaled_wmma accumulator/result types: 8 x f32 for 16x16x128 tiles and
+// 16 x f32 for 32x16x128 tiles.
+def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F32]>]>;
+
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
@@ -1226,6 +1235,72 @@ def AMDGPU_ScaledMFMAOp :
let hasCanonicalizer = 1;
}
+def AMDGPU_ScaledWMMAOp
+    : AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>,
+      Arguments<(ins ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+          ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
+          ScaledWMMAOutTypes:$destC,
+          VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$a_first_scale_lane,
+          VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$b_first_scale_lane)>,
+      Results<(outs ScaledWMMAOutTypes:$destD)> {
+  // TODO: E5M3FNU scales are supported, but there is not yet MLIR support for
+  // this datatype. Once we have support for that, update the scaleA and scaleB
+  // types here.
+  let summary = "MLIR wrapper for scaled wmma instructions";
+  let description = [{
+    The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
+    `wmma` instructions. These instructions perform matrix multiplication with
+    per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats.
+
+    The scaled instructions support a scale block size of 16 or 32 and two
+    tile sizes:
+    - 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>)
+    - 32x16x128 with f4 format only (output: vector<16xf32>)
+
+    Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
+    (either f8E8M0FNU or f8E4M3FN) that are packed into i32/i64 values during
+    lowering. The scale vector length selects the block size: 4-element scales
+    (packed into an i32) select block size 32, and 8-element scales (packed
+    into an i64) select block size 16. `scaleA` and `scaleB` must have the
+    same length.
+
+    The lane attributes (`a_first_scale_lane`, `b_first_scale_lane`) select
+    which register lanes provide scale values:
+    - Block size 32: For a tile size of 16x16x128, each matrix gets 64 scales
+      stored in half a VGPR, with `a_first_scale_lane`/`b_first_scale_lane`
+      selecting lanes 0-15 (index=0) or 16-31 (index=1). For a tile size of
+      32x16x128, matrix A gets 128 scales in a full VGPR
+      (`a_first_scale_lane` is unused), while matrix B gets 64 scales in half
+      a VGPR.
+
+    - Block size 16: For a tile size of 16x16x128, each matrix gets 128 scales
+      stored in half of two VGPRs, with
+      `a_first_scale_lane`/`b_first_scale_lane` selecting lanes 0-15 (index=0)
+      or 16-31 (index=1) for each of the VGPRs. For a tile size of 32x16x128,
+      matrix A gets 256 scales in two VGPRs (`a_first_scale_lane` is unused),
+      while matrix B gets 128 scales stored in half of two VGPRs.
+
+    Example:
+    ```mlir
+      // 16x16x128: fp8 inputs
+      %0 = amdgpu.scaled_wmma 16x16x128 (%scaleVecA * %matA) * (%scaleVecB * %matB) + %matC
+        {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+        : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
+          vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+
+      // 32x16x128: fp4 inputs with different scale indices
+      %1 = amdgpu.scaled_wmma 32x16x128 (%scaleVecD * %matD) * (%scaleVecE * %matE) + %matF
+        {a_first_scale_lane = 0 : i32, b_first_scale_lane = 1 : i32}
+        : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>,
+          vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+    ```
+  }];
+  let assemblyFormat = [{
+    custom<MNKDimensionList>($m, $n, $k) ` `
+    `(` $scaleA `*` $sourceA `)` `*`
+    `(` $scaleB `*` $sourceB `)` `+` $destC
+    attr-dict
+    `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
def AMDGPU_MakeDmaBaseOp :
AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins Arg<AnyMemRef>:$global,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a85973c2493ee..9018cd40b8682 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -612,8 +612,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
} // namespace
-/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
-/// and LLVM AMDGPU intrinsics convention.
+/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
+/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -627,9 +627,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
-static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
- Location loc, Value input,
- bool allowBf16 = true) {
+static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input,
+ bool allowBf16 = true) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
@@ -653,23 +653,59 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
return input;
}
-/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
-/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+/// Converts the scale operands of scaled MFMA/WMMA ops (`scalesA`/`scalesB`
+/// for MFMA, `scaleA`/`scaleB` for WMMA) from MLIR AMDGPU dialect convention
+/// to ROCDL and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
 /// 1. If `input` is a i8 value, zero extend it to i32
-/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
+/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32
+/// 3. If `input` is a vector of length 8 and type i8, bitcast it to i64
 ///
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
-                                  Location loc, Value input) {
+static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
+                              Value input) {
   Type inputType = input.getType();
-  Type outputType = rewriter.getI32Type();
+
+  // Handle scalar i8: zero extend to i32.
   if (auto intType = dyn_cast<IntegerType>(inputType))
-    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
-  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+    return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input);
+
+  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64: the packed integer is
+  // exactly as wide as the scale vector, so a bitcast suffices.
+  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
+    int64_t numElements = vectorType.getNumElements();
+    assert((numElements == 4 || numElements == 8) &&
+           "scale operand must be a vector of length 4 or 8");
+    IntegerType outputType =
+        (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
+    return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+  }
+
+  llvm_unreachable("unexpected input type for scale operand");
+}
+
+/// Returns the scaled-WMMA scale-format code for an f8 scale element type:
+/// 0 for f8E8M0FNU and 2 for f8E4M3FN. Any other type yields std::nullopt.
+static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
+  if (isa<Float8E8M0FNUType>(elemType))
+    return 0u;
+  if (isa<Float8E4M3FNType>(elemType))
+    return 2u;
+  return std::nullopt;
+}
+
+/// Returns the ROCDL scaled-WMMA intrinsic name for an MxNxK tile. When
+/// `isScale16` is set the `scale16` (block size 16) variant is chosen,
+/// otherwise the `scale` (block size 32) variant. Tile shapes without a
+/// scaled-WMMA intrinsic yield std::nullopt.
+static std::optional<StringRef>
+getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16) {
+  // Every supported tile is Mx16x128; only M varies.
+  if (n != 16 || k != 128)
+    return std::nullopt;
+
+  if (m == 16) {
+    if (isScale16)
+      return ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName();
+    return ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+  }
+
+  if (m == 32) {
+    if (isScale16)
+      return ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName();
+    return ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+  }
+
+  return std::nullopt;
 }
/// Push an input operand. If it is a float type, nothing to do. If it is
@@ -918,7 +954,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return std::nullopt;
}
-static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
.Case([](Float8E4M3FNType) { return 0u; })
.Case([](Float8E5M2Type) { return 1u; })
@@ -947,8 +983,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
if (!isa<Float32Type>(destType))
return std::nullopt;
- std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
- std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+ std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
+ std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
if (!aTypeCode || !bTypeCode)
return std::nullopt;
@@ -1212,9 +1248,9 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
- loweredOp.addOperands({convertMFMAVectorOperand(
+ loweredOp.addOperands({packSmallFloatVectorOperand(
rewriter, loc, adaptor.getSourceA(), allowBf16),
- convertMFMAVectorOperand(
+ packSmallFloatVectorOperand(
rewriter, loc, adaptor.getSourceB(), allowBf16),
adaptor.getDestC()});
if (isScaled) {
@@ -1261,8 +1297,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
- {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
- convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+ {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC()});
Value scalesIdxA =
createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
@@ -1273,10 +1309,10 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
createI32Constant(rewriter, loc, bTypeCode),
/*scales idx A=*/scalesIdxA,
/*scales A*/
- castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
+ castScaleOperand(rewriter, loc, adaptor.getScalesA()),
/*scales idx B=*/scalesIdxB,
/*scales B*/
- castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
+ castScaleOperand(rewriter, loc, adaptor.getScalesB())});
Value lowered = rewriter.create(loweredOp)->getResult(0);
rewriter.replaceOp(op, lowered);
return success();
@@ -1363,6 +1399,110 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+/// Lowers `amdgpu.scaled_wmma` to the matching `rocdl.wmma.scale*` intrinsic.
+/// Matrix operands are packed with packSmallFloatVectorOperand, scale vectors
+/// with castScaleOperand, and format/lane selections become attributes.
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // The result (and accumulator) type must convert to an LLVM vector.
+    auto resultType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    if (chipset < Chipset(12, 5, 0))
+      return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+    const int64_t m = op.getM();
+    const int64_t n = op.getN();
+    const int64_t k = op.getK();
+
+    // Matrix element format codes for A and B.
+    std::optional<uint32_t> fmtA = smallFloatTypeToFormatCode(
+        getElementTypeOrSelf(op.getSourceA().getType()));
+    std::optional<uint32_t> fmtB = smallFloatTypeToFormatCode(
+        getElementTypeOrSelf(op.getSourceB().getType()));
+    if (!fmtA || !fmtB)
+      return op.emitOpError("unsupported element types for scaled_wmma");
+
+    // The (shared) scale vector length picks the block-size variant.
+    auto scaleVecA = cast<VectorType>(op.getScaleA().getType());
+    auto scaleVecB = cast<VectorType>(op.getScaleB().getType());
+    const int64_t numScales = scaleVecA.getNumElements();
+    if (numScales != scaleVecB.getNumElements())
+      return op.emitOpError("scaleA and scaleB must have equal vector length");
+
+    // Scale element format codes for A and B.
+    std::optional<uint32_t> scaleFmtA =
+        getWmmaScaleFormat(scaleVecA.getElementType());
+    std::optional<uint32_t> scaleFmtB =
+        getWmmaScaleFormat(scaleVecB.getElementType());
+    if (!scaleFmtA || !scaleFmtB)
+      return op.emitOpError("unsupported scale element types");
+
+    std::optional<StringRef> intrinsic =
+        getScaledWmmaIntrinsicName(m, n, k, /*isScale16=*/numScales == 8);
+    if (!intrinsic)
+      return op.emitOpError("unsupported scaled_wmma dimensions: ")
+             << m << "x" << n << "x" << k;
+
+    SmallVector<NamedAttribute, 8> attrs;
+
+    // The 32x16x128 intrinsic is f4-only and takes no fmtA/fmtB attributes.
+    const bool hasMatrixFormats = !(m == 32 && n == 16 && k == 128);
+    if (hasMatrixFormats) {
+      attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*fmtA));
+      attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*fmtB));
+    }
+
+    // No C modifier (modC = 0).
+    attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
+
+    // Per-matrix scale lane selection and scale format.
+    attrs.emplace_back("scaleAType",
+                       rewriter.getI32IntegerAttr(op.getAFirstScaleLane()));
+    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleFmtA));
+    attrs.emplace_back("scaleBType",
+                       rewriter.getI32IntegerAttr(op.getBFirstScaleLane()));
+    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleFmtB));
+
+    // Operand reuse hints are left disabled.
+    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
+    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
+
+    // Operands, in intrinsic order: packed A, packed B, accumulator C, then
+    // the packed scale words for A and B. Braced-init elements are evaluated
+    // left to right, so the conversion ops are created in this order.
+    SmallVector<Value, 5> operands = {
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
+        adaptor.getDestC(),
+        castScaleOperand(rewriter, loc, adaptor.getScaleA()),
+        castScaleOperand(rewriter, loc, adaptor.getScaleB())};
+
+    OperationState state(loc, *intrinsic);
+    state.addTypes(resultType);
+    state.addOperands(operands);
+    state.addAttributes(attrs);
+
+    Operation *intrinsicOp = rewriter.create(state);
+    rewriter.replaceOp(op, intrinsicOp->getResults());
+    return success();
+  }
+};
+
struct TransposeLoadOpLowering
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
@@ -2408,10 +2548,11 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+ WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
+ ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering,
AMDGPUMakeDmaBaseLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index f78eca621da52..4dea96c5173f6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -442,6 +442,98 @@ LogicalResult WMMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledWMMAOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies tile-shape-dependent operand lengths and the legal combinations of
+/// matrix element types and scale element types for `amdgpu.scaled_wmma`.
+LogicalResult ScaledWMMAOp::verify() {
+  // Element-type classification helpers.
+  auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
+  auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
+  auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
+  auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
+  auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
+  auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
+
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  // ODS (ScaledWMMAInTypes) already restricts sources to fp4/fp6/fp8 vectors;
+  // only shape-specific and combination rules are checked here.
+  Type aElemType = sourceAType.getElementType();
+  Type bElemType = sourceBType.getElementType();
+
+  // Validate vector lengths based on dimensions (ODS restricts m to 16 or 32).
+  int64_t m = getM();
+  int64_t aLen = sourceAType.getNumElements();
+  int64_t bLen = sourceBType.getNumElements();
+  int64_t expectedOutLen = (m == 16) ? 8 : 16;
+
+  if (destType.getNumElements() != expectedOutLen)
+    return emitOpError("expected output vector of length ")
+           << expectedOutLen << " but got " << destType.getNumElements();
+
+  if (m == 16) {
+    // For 16x16x128: both A and B must be 64 elements.
+    if (aLen != 64)
+      return emitOpError(
+                 "for 16x16x128, sourceA must have 64 elements but got ")
+             << aLen;
+    if (bLen != 64)
+      return emitOpError(
+                 "for 16x16x128, sourceB must have 64 elements but got ")
+             << bLen;
+  } else { // m == 32
+    // For 32x16x128: the intrinsic is f4-only, so BOTH A and B must be fp4.
+    // Checking only A would let e.g. an f8 B through, which would then be
+    // packed to the wrong width for the f4 intrinsic. A is 128 elements and
+    // B is 64.
+    if (!isF4(aElemType) || !isF4(bElemType))
+      return emitOpError("32x16x128 only supports fp4 element types");
+
+    if (aLen != 128)
+      return emitOpError(
+                 "for 32x16x128, sourceA must have 128 elements but got ")
+             << aLen;
+    if (bLen != 64)
+      return emitOpError(
+                 "for 32x16x128, sourceB must have 64 elements but got ")
+             << bLen;
+  }
+
+  // Validate scale types and their compatibility with matrix element types.
+  auto scaleAType = cast<VectorType>(getScaleA().getType());
+  auto scaleBType = cast<VectorType>(getScaleB().getType());
+  Type scaleAElemType = scaleAType.getElementType();
+  Type scaleBElemType = scaleBType.getElementType();
+
+  // The scale vector length selects the scale block size (4 -> 32, 8 -> 16)
+  // of the single intrinsic that is emitted, so A and B must agree.
+  if (scaleAType.getNumElements() != scaleBType.getNumElements())
+    return emitOpError("scaleA and scaleB must have equal vector length");
+
+  // Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN).
+  if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
+    return emitOpError(
+        "scale operands must have f8 element types (E8M0FNU or E4M3FN)");
+
+  // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid.
+  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
+    return success();
+
+  // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E4M3).
+  if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
+      isF4(bElemType) && isE4M3(scaleBElemType))
+    return success();
+
+  // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E4M3), Scale B (E8M0).
+  if (isF4(aElemType) && isE4M3(scaleAElemType) &&
+      (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
+    return success();
+
+  // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3).
+  if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
+      isE4M3(scaleBElemType))
+    return success();
+
+  // No valid combination matched.
+  return emitOpError("invalid combination of matrix and scale types: ")
+         << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
+         << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
+}
+
//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 37259f6ed06eb..67749b4b9cb77 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -89,6 +89,73 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
return
}
+// CHECK-LABEL: @wmma_scale_16x16x128_fp8
+func.func @wmma_scale_16x16x128_fp8(%a_f8 : vector<64xf8E4M3FN>,
+                                    %a_f6 : vector<64xf6E2M3FN>,
+                                    %acc : vector<8xf32>,
+                                    %scale : vector<4xf8E8M0FNU>) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%scale * %a_f8) * (%scale * %a_f8) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%scale * %a_f6) * (%scale * %a_f6) + %acc
+         {a_first_scale_lane = 1 : i32, b_first_scale_lane = 0 : i32}
+         : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_fp6
+func.func @wmma_scale_16x16x128_fp6(%a_e2m3 : vector<64xf6E2M3FN>,
+                                    %a_e3m2 : vector<64xf6E3M2FN>,
+                                    %acc : vector<8xf32>,
+                                    %scale : vector<4xf8E8M0FNU>) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%scale * %a_e2m3) * (%scale * %a_e2m3) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%scale * %a_e3m2) * (%scale * %a_e3m2) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_mixed
+func.func @wmma_scale_16x16x128_mixed(%a_f8 : vector<64xf8E4M3FN>,
+                                      %a_f6 : vector<64xf6E2M3FN>,
+                                      %b_f4 : vector<64xf4E2M1FN>,
+                                      %acc : vector<8xf32>,
+                                      %scale_e8m0 : vector<4xf8E8M0FNU>,
+                                      %scale_e4m3 : vector<4xf8E4M3FN>) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%scale_e8m0 * %a_f8) * (%scale_e4m3 * %b_f4) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<12xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%scale_e8m0 * %a_f6) * (%scale_e4m3 * %b_f4) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_16x16x128_fp8
+func.func @wmma_scale16_16x16x128_fp8(%a_f8 : vector<64xf8E4M3FN>,
+                                      %a_f6 : vector<64xf6E3M2FN>,
+                                      %acc : vector<8xf32>,
+                                      %scale : vector<8xf8E8M0FNU>) {
+  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%scale * %a_f8) * (%scale * %a_f8) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%scale * %a_f6) * (%scale * %a_f6) + %acc
+         {a_first_scale_lane = 1 : i32, b_first_scale_lane = 0 : i32}
+         : vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_32x16x128_fp4
+func.func @wmma_scale_32x16x128_fp4(%a_f4 : vector<128xf4E2M1FN>,
+                                    %b_f4 : vector<64xf4E2M1FN>,
+                                    %acc : vector<16xf32>,
+                                    %scale : vector<4xf8E4M3FN>) {
+  // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%scale * %a_f4) * (%scale * %b_f4) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_32x16x128_fp4
+func.func @wmma_scale16_32x16x128_fp4(%a_f4 : vector<128xf4E2M1FN>,
+                                      %b_f4 : vector<64xf4E2M1FN>,
+                                      %acc : vector<16xf32>,
+                                      %scale : vector<8xf8E4M3FN>) {
+  // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%scale * %a_f4) * (%scale * %b_f4) + %acc
+         {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32}
+         : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>, vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+
+  func.return
+}
+
// -----
func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
@@ -97,3 +164,48 @@ func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>
return
}
+
+// -----
+
+// Verifier: a 16x16x128 tile must accumulate into vector<8xf32>.
+func.func @scaled_wmma_wrong_output_length(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<16xf32>,
+                                           %arg2 : vector<4xf8E8M0FNU>) {
+  // expected-error @below {{'amdgpu.scaled_wmma' op expected output vector of length 8 but got 16}}
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg2 * %arg0) * (%arg2 * %arg0) + %arg1 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<16xf32>
+  return
+}
+
+// -----
+
+// Verifier: 16x16x128 requires a 64-element sourceA.
+func.func @scaled_wmma_16x16_wrong_sourceA_length(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+                                                  %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
+  // expected-error @below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceA must have 64 elements but got 128}}
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf4E2M1FN>, vector<8xf32>
+  return
+}
+
+// -----
+
+// Verifier: 16x16x128 requires a 64-element sourceB.
+func.func @scaled_wmma_16x16_wrong_sourceB_length(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<128xf4E2M1FN>,
+                                                  %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
+  // expected-error @below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceB must have 64 elements but got 128}}
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<8xf32>
+  return
+}
+
+// -----
+
+// Verifier: 32x16x128 requires a 128-element sourceA.
+func.func @scaled_wmma_32x16_wrong_sourceA_length(%arg0 : vector<64xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+                                                  %arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
+  // expected-error @below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceA must have 128 elements but got 64}}
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+  return
+}
+
+// -----
+
+// Verifier: 32x16x128 requires a 64-element sourceB.
+func.func @scaled_wmma_32x16_wrong_sourceB_length(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<128xf4E2M1FN>,
+                                                  %arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
+  // expected-error @below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceB must have 64 elements but got 128}}
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3 * %arg0) * (%arg3 * %arg1) + %arg2 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<16xf32>
+  return
+}
+
+// -----
+
+// Verifier: an f6 sourceB with E4M3 scales matches no legal combination.
+func.func @scaled_wmma_invalid_type_combination(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
+                                                %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>,
+                                                %arg4 : vector<4xf8E4M3FN>) {
+  // expected-error @below {{'amdgpu.scaled_wmma' op invalid combination of matrix and scale types}}
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf6E2M3FN>, vector<8xf32>
+  return
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 0eccd0a7430bc..1c982561b6c8d 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -742,7 +742,28 @@ func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8x
// CHECK-SAME: iterate %[[IDX]], %[[IDX]], %[[IDX]]
iterate %idx, %idx, %idx
: !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+ func.return
+}
+// Positive coverage for amdgpu.scaled_wmma: the patterns below expect each op
+// to print back in the same custom form it was written in. Cases covered:
+// 16x16x128 with f8E4M3FN, both f6 variants (E3M2 and E2M3), mixed f4 A
+// (f8E4M3FN scale) x f6 B (f8E8M0FNU scale), 8-element E8M0 scale vectors,
+// and 32x16x128 f4 with a 128-element A and a 64-element B (16xf32 result).
+// CHECK-LABEL: func @wmma_scale
+func.func @wmma_scale(%fp8_src: vector<64xf8E4M3FN>, %fp6_alt_src: vector<64xf6E3M2FN>,
+ %fp6_src: vector<64xf6E2M3FN>, %fp4_src_a: vector<128xf4E2M1FN>,
+ %fp4_src_b: vector<64xf4E2M1FN>,
+ %dst0: vector<8xf32>, %dst1: vector<16xf32>,
+ %scale_vec4: vector<4xf8E8M0FNU>, %scale_vec8: vector<8xf8E8M0FNU>,
+ %scale_vec4_e4m3: vector<4xf8E4M3FN>) {
+ // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+ %0 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4 * %fp8_src) * (%scale_vec4 * %fp8_src) + %dst0 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+ // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+ %1 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4 * %fp6_alt_src) * (%scale_vec4 * %fp6_alt_src) + %dst0 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+ // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+ %2 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4 * %fp6_src) * (%scale_vec4 * %fp6_src) + %dst0 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+ // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+ %3 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4_e4m3 * %fp4_src_b) * (%scale_vec4 * %fp6_src) + %dst0 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+ // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+ %4 = amdgpu.scaled_wmma 16x16x128 (%scale_vec8 * %fp8_src) * (%scale_vec8 * %fp8_src) + %dst0 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+ // CHECK: amdgpu.scaled_wmma 32x16x128 ({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+ %5 = amdgpu.scaled_wmma 32x16x128 (%scale_vec4_e4m3 * %fp4_src_a) * (%scale_vec4_e4m3 * %fp4_src_b) + %dst1 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
 func.return
}
>From e330cc17caae5d89b2303d354d990b7de7dd38c1 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Fri, 5 Dec 2025 15:45:28 +0000
Subject: [PATCH 2/2] Minor copilot comments
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 4 ++--
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 6 +++---
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 63ce55dda8e98..5d27bd9c558d9 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1257,8 +1257,8 @@ def AMDGPU_ScaledWMMAOp
per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats.
The scale instructions support a block size of 16 or 32 and two tile sizes:
- - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
- - 32x16x128 with f4 format only (output: vector<8xf32>)
+ - 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>)
+ - 32x16x128 with f4 format only (output: vector<16xf32>)
Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
(either f8E8M0FNU, or f8E4M3FN) that are packed into i32/i64 values during
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4dea96c5173f6..c1f7bca8ad69f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -485,7 +485,7 @@ LogicalResult ScaledWMMAOp::verify() {
<< bLen;
} else { // m == 32
// For 32×16×128: only fp4 is supported, A is 128, B is 64.
- if (!isF4(aElemType))
+ if (!isF4(aElemType) && !isF4(bElemType))
return emitOpError("32x16x128 only supports fp4 element types");
if (aLen != 128)
@@ -513,12 +513,12 @@ LogicalResult ScaledWMMAOp::verify() {
if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
return success();
- // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M2|E4M3).
+ // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M3|E4M3).
if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
isF4(bElemType) && isE4M3(scaleBElemType))
return success();
- // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M2|E4M3), Scale B (E8M0).
+ // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M3|E4M3), Scale B (E8M0).
if (isF4(aElemType) && isE4M3(scaleAElemType) &&
(isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
return success();
More information about the Mlir-commits
mailing list