[Mlir-commits] [mlir] [mlir][amdgpu] implement amdgpu.sparse_mfma wrapper for smfmac instructions (PR #171968)
Eric Feng
llvmlistbot at llvm.org
Thu Dec 18 14:26:15 PST 2025
https://github.com/efric updated https://github.com/llvm/llvm-project/pull/171968
>From 73caf01662dba501bb1c887bdffa0eef6a3ba678 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:18:17 -0800
Subject: [PATCH 01/10] implement amdgpu wrapper for smfmac
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 81 ++++++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 176 +++++++++++++++++-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 56 ++++++
.../AMDGPUToROCDL/sparse-mfma-gfx950.mlir | 53 ++++++
.../Conversion/AMDGPUToROCDL/sparse-mfma.mlir | 61 ++++++
5 files changed, 422 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56160d3e8fe85..9b4947049c388 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -958,6 +958,27 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
+
+// sparse_mfma (smfmac)
+def SMFMACSparseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 8], [F16]>,
+ VectorOfLengthAndType<[4, 8], [BF16]>,
+ VectorOfLengthAndType<[8, 16], [I8]>,
+ VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>
+]>;
+
+def SMFMACDenseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[8, 16], [F16]>,
+ VectorOfLengthAndType<[8, 16], [BF16]>,
+ VectorOfLengthAndType<[16, 32], [I8]>,
+ VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>
+]>;
+
+def SMFMACOutTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 16], [F32]>,
+ VectorOfLengthAndType<[4, 16], [I32]>
+]>;
+
// scaled_mfma
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
@@ -1097,6 +1118,66 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
+def AMDGPU_SparseMFMAOp :
+ AMDGPU_Op<"sparse_mfma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32, 64, 128]>]>:$k,
+ SMFMACSparseInTypes:$sourceA,
+ SMFMACDenseInTypes:$sourceB,
+ SMFMACOutTypes:$destC,
+ I32:$sparseIdx,
+ DefaultValuedAttr<I32Attr, "0">:$cbsz,
+ DefaultValuedAttr<I32Attr, "0">:$abid)>,
+ Results<(outs SMFMACOutTypes: $destD)> {
+ let summary = "MLIR wrapper for CDNA sparse mfma (smfmac) instructions";
+ let description = [{
+ The `amdgpu.sparse_mfma` op is an MLIR wrapper around intrinsics for various
+ `smfmac` instructions in the AMDGPU architecture, which perform matrix
+ multiply-accumulate operations using 2:4 structured sparsity on matrix A
+ with dense matrices B, C, and D.
+
+ On gfx940, smfmac intrinsics support:
+ - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources
+ - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources
+
+ On gfx950, smfmac intrinsics additionally support:
+ - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources
+ - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources
+
+ The `sparseIdx` parameter (i32) contains packed indices identifying the
+ positions of non-zero elements in the 2:4 sparse matrix A. For 16-bit data,
+ this uses four groups of 8-bit indices; for 8-bit data, 2 groups of 16-bit
+ indices.
+
+ The `cbsz` and `abid` parameters are repurposed to select the index set.
+ If `cbsz == 0`, then `abid[1:0]` selects which index set to use.
+ If `cbsz != 0`, then the very first is selected.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx)
+ : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+ %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx)
+ : vector<8xf16>, vector<16xf16>, vector<4xf32>
+
+ %2 = amdgpu.sparse_mfma 16x16x128 %matA * %matB + %matC sparse(%idx)
+ { cbsz = 0 : i32, abid = 1 : i32 }
+ : vector<4xi32>, vector<8xi32>, vector<4xi32>
+ ```
+ }];
+ let assemblyFormat = [{
+ custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+ `sparse` `(` $sparseIdx `)`
+ attr-dict
+ `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_GatherToLDSOp :
AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4b1509392aa6f..855d0c9df4281 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -661,6 +661,30 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
return input;
}
+/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
+static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input,
+ bool allowBf16 = true) {
+ Type inputType = input.getType();
+ if (auto vectorType = dyn_cast<VectorType>(inputType)) {
+ // bf16 -> i16 when not allowed (pre-gfx950)
+ if (vectorType.getElementType().isBF16() && !allowBf16)
+ return LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
+ // i8/fp8 vectors -> vector<Nxi32>
+ if (isa<IntegerType>(vectorType.getElementType()) &&
+ vectorType.getElementTypeBitWidth() <= 8) {
+ int64_t numWords = llvm::divideCeil(
+ vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
+ 32);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+ input);
+ }
+ }
+ return input;
+}
+
/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
@@ -1136,6 +1160,104 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
return std::nullopt;
}
+/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
+/// operation if one exists. This includes checking to ensure the intrinsic is
+/// supported on the architecture you are compiling for.
+static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
+ bool isGfx950) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
+ uint32_t m = op.getM(), n = op.getN(), k = op.getK();
+ Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
+ Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
+ Type destElem = getElementTypeOrSelf(op.getDestC().getType());
+
+ if (m == 16 && n == 16 && k == 32) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
+ }
+
+ if (m == 16 && n == 16 && k == 64) {
+ if (isGfx950) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
+ }
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
+ }
+
+ if (m == 16 && n == 16 && k == 128 && isGfx950) {
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
+ }
+
+ if (m == 32 && n == 32 && k == 16) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
+ }
+
+ if (m == 32 && n == 32 && k == 32) {
+ if (isGfx950) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
+ }
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
+ }
+
+ if (m == 32 && n == 32 && k == 64 && isGfx950) {
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
+ if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
+ if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
+ }
+
+ return std::nullopt;
+}
+
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -1291,6 +1413,49 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
}
};
+struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
+ SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto outType =
+ typeConverter->convertType<VectorType>(op.getDestC().getType());
+ if (!outType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ // smfmac is supported on gfx942 and gfx950
+ if (chipset.majorVersion != 9 || chipset < kGfx942)
+ return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
+ bool isGfx950 = chipset >= kGfx950;
+
+ Value a = convertSparseMFMAVectorOperand(rewriter, loc,
+ adaptor.getSourceA(), isGfx950);
+ Value b = convertSparseMFMAVectorOperand(rewriter, loc,
+ adaptor.getSourceB(), isGfx950);
+ Value c = adaptor.getDestC();
+
+ std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, isGfx950);
+
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching sparse MFMA on the given chipset");
+
+ OperationState loweredOp(loc, maybeIntrinsic.value());
+ loweredOp.addTypes(outType);
+ loweredOp.addOperands({a, b, c, adaptor.getSparseIdx(),
+ createI32Constant(rewriter, loc, op.getCbsz()),
+ createI32Constant(rewriter, loc, op.getAbid())});
+ Value lowered = rewriter.create(loweredOp)->getResult(0);
+ rewriter.replaceOp(op, lowered);
+ return success();
+ }
+};
+
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -2797,11 +2962,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
- SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+ WMMAOpLowering, SparseMFMAOpLowering, ExtPackedFp8OpLowering,
+ ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering,
AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..2cc1aaa8e3b2d 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -522,6 +522,62 @@ LogicalResult MFMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// SparseMFMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SparseMFMAOp::verify() {
+ constexpr uint32_t waveSize = 64;
+
+ auto sparseType = cast<VectorType>(getSourceA().getType());
+ auto denseType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ Type sparseElem = sparseType.getElementType();
+ Type denseElem = denseType.getElementType();
+ int64_t sparseLen = sparseType.getNumElements();
+ int64_t denseLen = denseType.getNumElements();
+ int64_t destLen = destType.getNumElements();
+
+ if (denseLen != 2 * sparseLen)
+ return emitOpError("expected dense source operand to have exactly double "
+ "the number of elements of the sparse source operand");
+
+ // Check that source element types are compatible.
+ // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
+ // For other types, element types must match exactly.
+ bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
+ if (!bothFloat8 && sparseElem != denseElem)
+ return emitOpError(
+ "expected source operands to have the same element type");
+
+ // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
+ // When CBSZ != 0, the first index set is always used (ABID ignored).
+ bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
+ if (getCbsz() == 0 && is8BitSource) {
+ // 8-bit source: ABID[0] selects one of two 16-bit index sets.
+ if (getAbid() > 1)
+ return emitOpError(
+ "ABID must be 0 or 1 for 8-bit source data when CBSZ is 0");
+ }
+ // 16-bit source: ABID[1:0] selects one of four 8-bit index sets (0-3 all
+ // valid).
+
+ int64_t expectedSourceElems = (getM() * getK()) / waveSize;
+ if (denseLen != expectedSourceElems)
+ return emitOpError("expected " + Twine(expectedSourceElems) +
+ " source values for this operation but got " +
+ Twine(denseLen));
+
+ int64_t expectedDestElems = (getM() * getN()) / waveSize;
+ if (destLen != expectedDestElems)
+ return emitOpError("expected " + Twine(expectedDestElems) +
+ " result values for this operation but got " +
+ Twine(destLen));
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
new file mode 100644
index 0000000000000..abe2565f7c41b
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>,
+ %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+ %arg4 : vector<8xbf16>, %arg5 : vector<16xbf16>,
+ %arg6 : vector<16xi8>, %arg7 : vector<32xi8>,
+ %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+ %arg10 : vector<16xf8E4M3FN>, %arg11 : vector<16xf8E5M2>,
+ %arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>,
+ %arg14 : i32) {
+ // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
+
+ func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
new file mode 100644
index 0000000000000..65a0cd3f1f87f
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 -cse | FileCheck %s
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
+ %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+ %arg4 : vector<4xbf16>, %arg5 : vector<8xbf16>,
+ %arg6 : vector<8xi8>, %arg7 : vector<16xi8>,
+ %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+ %arg10 : vector<8xf8E4M3FN>, %arg11 : vector<8xf8E5M2>,
+ %arg12 : vector<16xf8E4M3FN>, %arg13 : vector<16xf8E5M2>,
+ %arg14 : i32) {
+ // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
+ // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
+
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
+
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<16xf32>
+
+ func.return
+}
>From e05394559965b8fdb6001313d651242683d049ed Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:30:42 -0800
Subject: [PATCH 02/10] nits
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2cc1aaa8e3b2d..4231014c77982 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -554,14 +554,12 @@ LogicalResult SparseMFMAOp::verify() {
// When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
// When CBSZ != 0, the first index set is always used (ABID ignored).
bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
- if (getCbsz() == 0 && is8BitSource) {
- // 8-bit source: ABID[0] selects one of two 16-bit index sets.
- if (getAbid() > 1)
- return emitOpError(
- "ABID must be 0 or 1 for 8-bit source data when CBSZ is 0");
- }
- // 16-bit source: ABID[1:0] selects one of four 8-bit index sets (0-3 all
- // valid).
+ // 8-bit source: ABID selects one of two 16-bit index sets.
+ if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
+ return emitOpError("ABID must be 0 or 1 for 8-bit source data");
+ // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid).
+ if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
+ return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
int64_t expectedSourceElems = (getM() * getK()) / waveSize;
if (denseLen != expectedSourceElems)
>From f3d945681062873b967dc4c5711b495d85816d67 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:36:58 -0800
Subject: [PATCH 03/10] nit
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 9b4947049c388..b23343f5cdd1e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1139,7 +1139,7 @@ def AMDGPU_SparseMFMAOp :
multiply-accumulate operations using 2:4 structured sparsity on matrix A
with dense matrices B, C, and D.
- On gfx940, smfmac intrinsics support:
+ On gfx942, smfmac intrinsics support:
- M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources
- M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources
>From 248ad159f19e988c2a3eaebc9672a1ce97b9c307 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:39:01 -0800
Subject: [PATCH 04/10] format thing
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 46 +++++++++----------
1 file changed, 23 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 855d0c9df4281..65a68d417b663 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2947,28 +2947,28 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
Chipset chipset) {
populateAMDGPUTypeAndAttributeConversions(converter);
- patterns.add<
- FatRawBufferCastLowering,
- RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
- RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
- RawBufferOpLowering<RawBufferAtomicFaddOp,
- ROCDL::RawPtrBufferAtomicFaddOp>,
- RawBufferOpLowering<RawBufferAtomicFmaxOp,
- ROCDL::RawPtrBufferAtomicFmaxOp>,
- RawBufferOpLowering<RawBufferAtomicSmaxOp,
- ROCDL::RawPtrBufferAtomicSmaxOp>,
- RawBufferOpLowering<RawBufferAtomicUminOp,
- ROCDL::RawPtrBufferAtomicUminOp>,
- RawBufferOpLowering<RawBufferAtomicCmpswapOp,
- ROCDL::RawPtrBufferAtomicCmpSwap>,
- AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
- SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, SparseMFMAOpLowering, ExtPackedFp8OpLowering,
- ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
- PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering, AMDGPUPermlaneLowering,
- AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
- chipset);
+ patterns
+ .add<FatRawBufferCastLowering,
+ RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
+ RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
+ RawBufferOpLowering<RawBufferAtomicFaddOp,
+ ROCDL::RawPtrBufferAtomicFaddOp>,
+ RawBufferOpLowering<RawBufferAtomicFmaxOp,
+ ROCDL::RawPtrBufferAtomicFmaxOp>,
+ RawBufferOpLowering<RawBufferAtomicSmaxOp,
+ ROCDL::RawPtrBufferAtomicSmaxOp>,
+ RawBufferOpLowering<RawBufferAtomicUminOp,
+ ROCDL::RawPtrBufferAtomicUminOp>,
+ RawBufferOpLowering<RawBufferAtomicCmpswapOp,
+ ROCDL::RawPtrBufferAtomicCmpSwap>,
+ AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+ WMMAOpLowering, SparseMFMAOpLowering, ExtPackedFp8OpLowering,
+ ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+ AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(
+ converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
>From b18719d7811e411b37f6dddc7054b71950ab8518 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:59:09 -0800
Subject: [PATCH 05/10] nit
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 65a68d417b663..c0f089b9fdb67 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -667,11 +667,11 @@ static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
bool allowBf16 = true) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
- // bf16 -> i16 when not allowed (pre-gfx950)
+ // bf16 -> i16 when not allowed (pre-gfx950).
if (vectorType.getElementType().isBF16() && !allowBf16)
return LLVM::BitcastOp::create(
rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
- // i8/fp8 vectors -> vector<Nxi32>
+ // i8/fp8 vectors -> vector<Nxi32>.
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
>From a68c02bd3684a3456605281b013f692b89375f48 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 12 Dec 2025 09:48:01 -0800
Subject: [PATCH 06/10] nits
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 6 +-
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 69 +++++++++----------
.../Conversion/AMDGPUToROCDL/sparse-mfma.mlir | 20 +++---
3 files changed, 47 insertions(+), 48 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b23343f5cdd1e..7d8c41b4c95cb 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -964,14 +964,16 @@ def SMFMACSparseInTypes : AnyTypeOf<[
VectorOfLengthAndType<[4, 8], [F16]>,
VectorOfLengthAndType<[4, 8], [BF16]>,
VectorOfLengthAndType<[8, 16], [I8]>,
- VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>
+ VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[8, 16], [F8E4M3FNUZ, F8E5M2FNUZ]>
]>;
def SMFMACDenseInTypes : AnyTypeOf<[
VectorOfLengthAndType<[8, 16], [F16]>,
VectorOfLengthAndType<[8, 16], [BF16]>,
VectorOfLengthAndType<[16, 32], [I8]>,
- VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>
+ VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
]>;
def SMFMACOutTypes : AnyTypeOf<[
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c0f089b9fdb67..51388cbfa458e 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -666,21 +666,18 @@ static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Location loc, Value input,
bool allowBf16 = true) {
Type inputType = input.getType();
- if (auto vectorType = dyn_cast<VectorType>(inputType)) {
- // bf16 -> i16 when not allowed (pre-gfx950).
- if (vectorType.getElementType().isBF16() && !allowBf16)
- return LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
- // i8/fp8 vectors -> vector<Nxi32>.
- if (isa<IntegerType>(vectorType.getElementType()) &&
- vectorType.getElementTypeBitWidth() <= 8) {
- int64_t numWords = llvm::divideCeil(
- vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
- 32);
- return LLVM::BitcastOp::create(
- rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
- input);
- }
+ auto vectorType = cast<VectorType>(inputType);
+ // bf16 -> i16 when not allowed (pre-gfx950).
+ if (vectorType.getElementType().isBF16() && !allowBf16)
+ return LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
+ // i8/fp8 vectors -> vector<Nxi32>.
+ if (isa<IntegerType>(vectorType.getElementType()) &&
+ vectorType.getElementTypeBitWidth() <= 8) {
+ int64_t numWords = llvm::divideCeil(
+ vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
}
return input;
}
@@ -1164,9 +1161,10 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
/// operation if one exists. This includes checking to ensure the intrinsic is
/// supported on the architecture you are compiling for.
static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
- bool isGfx950) {
- using fp8 = Float8E4M3FNType;
- using bf8 = Float8E5M2Type;
+ Chipset chipset) {
+ bool isGfx950 = chipset >= kGfx950;
+ auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
+ auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
uint32_t m = op.getM(), n = op.getN(), k = op.getK();
Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
@@ -1190,13 +1188,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
}
@@ -1204,13 +1202,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
}
@@ -1231,13 +1229,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
}
@@ -1245,13 +1243,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
- if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
- if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
}
@@ -1439,8 +1437,7 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
adaptor.getSourceB(), isGfx950);
Value c = adaptor.getDestC();
- std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, isGfx950);
-
+ std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
if (!maybeIntrinsic.has_value())
return op.emitOpError(
"no intrinsic matching sparse MFMA on the given chipset");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
index 65a0cd3f1f87f..a1784ce95de49 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -4,8 +4,8 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
%arg4 : vector<4xbf16>, %arg5 : vector<8xbf16>,
%arg6 : vector<8xi8>, %arg7 : vector<16xi8>,
%arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
- %arg10 : vector<8xf8E4M3FN>, %arg11 : vector<8xf8E5M2>,
- %arg12 : vector<16xf8E4M3FN>, %arg13 : vector<16xf8E5M2>,
+ %arg10 : vector<8xf8E4M3FNUZ>, %arg11 : vector<8xf8E5M2FNUZ>,
+ %arg12 : vector<16xf8E4M3FNUZ>, %arg13 : vector<16xf8E5M2FNUZ>,
%arg14 : i32) {
// CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
@@ -29,33 +29,33 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
// CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
// CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
// CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
// CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
// CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
func.return
}
>From 407862a654fa673579d448ea160c99bdc52c19e6 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 12 Dec 2025 09:49:02 -0800
Subject: [PATCH 07/10] nit
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 51388cbfa458e..66136ab547022 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1426,7 +1426,7 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
if (!outType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
- // smfmac is supported on gfx942 and gfx950
+ // smfmac is supported on gfx942 and gfx950.
if (chipset.majorVersion != 9 || chipset < kGfx942)
return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
bool isGfx950 = chipset >= kGfx950;
>From 6ad2e3ac4594ade6df985320ef466c9602244061 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 16 Dec 2025 14:52:01 -0800
Subject: [PATCH 08/10] format
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f1e74d4c6f61c..b538e09441c83 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -3530,12 +3530,11 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
- ExtPackedFp8OpLowering,
- ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
- PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering, AMDGPUPermlaneLowering,
- AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+ ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
+ ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GatherToLDSOpLowering, TransposeLoadOpLowering,
+ AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
>From 46fa5061c5ef8b05d10e7fe3cc3244e1f585dc08 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 18 Dec 2025 12:20:41 -0800
Subject: [PATCH 09/10] address review
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 27 ++++----
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 +-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 18 ++++++
.../AMDGPUToROCDL/sparse-mfma-gfx950.mlir | 38 ++++++-----
.../Conversion/AMDGPUToROCDL/sparse-mfma.mlir | 48 +++++++-------
mlir/test/Dialect/AMDGPU/invalid.mlir | 64 +++++++++++++++++++
6 files changed, 151 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index dbc01c33f2853..a2bb755f40b16 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1013,6 +1013,11 @@ def SMFMACOutTypes : AnyTypeOf<[
VectorOfLengthAndType<[4, 16], [I32]>
]>;
+def SparseMFMAIdxTypes : AnyTypeOf<[
+ FixedVectorOfLengthAndType<[4], [I8]>,
+ FixedVectorOfLengthAndType<[2], [I16]>
+]>;
+
// scaled_mfma
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
@@ -1171,7 +1176,7 @@ def AMDGPU_SparseMFMAOp :
SMFMACSparseInTypes:$sourceA,
SMFMACDenseInTypes:$sourceB,
SMFMACOutTypes:$destC,
- I32:$sparseIdx,
+ SparseMFMAIdxTypes:$sparseIdx,
DefaultValuedAttr<I32Attr, "0">:$cbsz,
DefaultValuedAttr<I32Attr, "0">:$abid)>,
Results<(outs SMFMACOutTypes: $destD)> {
@@ -1190,10 +1195,10 @@ def AMDGPU_SparseMFMAOp :
- M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources
- M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources
- The `sparseIdx` parameter (i32) contains packed indices identifying the
- positions of non-zero elements in the 2:4 sparse matrix A. For 16-bit data,
- this uses four groups of 8-bit indices; for 8-bit data, 2 groups of 16-bit
- indices.
+ The `sparseIdx` parameter contains packed indices identifying the positions
+ of non-zero elements in the 2:4 sparse matrix A. For 16-bit source data,
+ use `vector<4xi8>` (four 8-bit indices). For 8-bit source data, use
+ `vector<2xi16>` (two 16-bit indices).
The `cbsz` and `abid` parameters are repurposed to select the index set.
If `cbsz == 0`, then `abid[1:0]` selects which index set to use.
@@ -1201,20 +1206,20 @@ def AMDGPU_SparseMFMAOp :
Example:
```mlir
- %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx)
+ %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
: vector<4xf16>, vector<8xf16>, vector<4xf32>
- %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx)
- : vector<8xf16>, vector<16xf16>, vector<4xf32>
+ %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
+ : vector<8xi8>, vector<16xi8>, vector<4xi32>
- %2 = amdgpu.sparse_mfma 16x16x128 %matA * %matB + %matC sparse(%idx)
+ %2 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
{ cbsz = 0 : i32, abid = 1 : i32 }
- : vector<4xi32>, vector<8xi32>, vector<4xi32>
+ : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
```
}];
let assemblyFormat = [{
custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
- `sparse` `(` $sparseIdx `)`
+ `sparse` `(` $sparseIdx `:` type($sparseIdx) `)`
attr-dict
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
}];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b538e09441c83..5dcd24019412a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1477,9 +1477,13 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
return op.emitOpError(
"no intrinsic matching sparse MFMA on the given chipset");
+ // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
+ Value sparseIdx = LLVM::BitcastOp::create(
+ rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
+
OperationState loweredOp(loc, maybeIntrinsic.value());
loweredOp.addTypes(outType);
- loweredOp.addOperands({a, b, c, adaptor.getSparseIdx(),
+ loweredOp.addOperands({a, b, c, sparseIdx,
createI32Constant(rewriter, loc, op.getCbsz()),
createI32Constant(rewriter, loc, op.getAbid())});
Value lowered = rewriter.create(loweredOp)->getResult(0);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 8bcdaec0bf3b1..31f87aaa3ce74 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -671,6 +671,24 @@ LogicalResult SparseMFMAOp::verify() {
if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
+ // Validate sparseIdx type matches source element type.
+ auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
+ if (is8BitSource) {
+ // 8-bit source data requires vector<2xi16> sparse indices.
+ if (sparseIdxType.getNumElements() != 2 ||
+ !sparseIdxType.getElementType().isInteger(16))
+ return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
+ "source data, but got ")
+ << sparseIdxType;
+ } else {
+ // 16-bit source data requires vector<4xi8> sparse indices.
+ if (sparseIdxType.getNumElements() != 4 ||
+ !sparseIdxType.getElementType().isInteger(8))
+ return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
+ "source data, but got ")
+ << sparseIdxType;
+ }
+
int64_t expectedSourceElems = (getM() * getK()) / waveSize;
if (denseLen != expectedSourceElems)
return emitOpError("expected " + Twine(expectedSourceElems) +
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
index abe2565f7c41b..266e0e7e15595 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
@@ -6,48 +6,56 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>,
%arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
%arg10 : vector<16xf8E4M3FN>, %arg11 : vector<16xf8E5M2>,
%arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>,
- %arg14 : i32) {
+ %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
+ // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
// CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
+ // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
// CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
+ amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
// CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
// CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
// CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
+ amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
// CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
func.return
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
index a1784ce95de49..b2c91c3d9bed1 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -6,56 +6,58 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
%arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
%arg10 : vector<8xf8E4M3FNUZ>, %arg11 : vector<8xf8E5M2FNUZ>,
%arg12 : vector<16xf8E4M3FNUZ>, %arg13 : vector<16xf8E5M2FNUZ>,
- %arg14 : i32) {
+ %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
+ // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
// CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
+ // CHECK: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+ // CHECK: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
// CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
// CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+ amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
- // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
// CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
+ amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
func.return
}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 9ece57e9ec6a3..6785946fa60b1 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -452,3 +452,67 @@ func.func @make_gather_dma_descriptor_invalid_index_types(%base: !amdgpu.tdm_gat
amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2] : !amdgpu.tdm_gather_base<i32, i16>, vector<8xi32> -> !amdgpu.tdm_descriptor
func.return
}
+
+// -----
+
+func.func @sparse_mfma_dense_not_double_sparse(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op operand #1 must be vector of 16-bit float values of length 8/16 or vector of bfloat16 type values of length 8/16 or vector of 8-bit signless integer values of length 16/32 or vector of f8E4M3FN type or f8E5M2 type values of length 16/32 or vector of f8E4M3FNUZ type or f8E5M2FNUZ type values of length 16/32, but got 'vector<4xf16>'}}
+ %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<4xf16>, vector<4xf32>
+ func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_mismatched_source_types(%a: vector<4xf16>, %b: vector<8xbf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op expected source operands to have the same element type}}
+ %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xbf16>, vector<4xf32>
+ func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_abid_invalid_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<2xi16>) -> vector<4xi32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op ABID must be 0 or 1 for 8-bit source data}}
+ %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<2xi16>) { abid = 2 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+ func.return %d : vector<4xi32>
+}
+
+// -----
+
+func.func @sparse_mfma_abid_invalid_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op ABID must be between 0 and 3 for 16-bit source data}}
+ %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) { abid = 4 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+ func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_idx_type_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<4xi8>) -> vector<4xi32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op expected vector<2xi16> sparse indices for 8-bit source data, but got}}
+ %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<8xi8>, vector<16xi8>, vector<4xi32>
+ func.return %d : vector<4xi32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_idx_type_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<2xi16>) -> vector<4xf32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op expected vector<4xi8> sparse indices for 16-bit source data, but got}}
+ %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<2xi16>) : vector<4xf16>, vector<8xf16>, vector<4xf32>
+ func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_source_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op expected 16 source values for this operation but got 8}}
+ %d = amdgpu.sparse_mfma 32x32x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32>
+ func.return %d : vector<16xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_dest_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> {
+ // expected-error at +1 {{'amdgpu.sparse_mfma' op expected 4 result values for this operation but got 16}}
+ %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32>
+ func.return %d : vector<16xf32>
+}
>From 8f9416da45ea2830ead6c0624c8b0240b9e9c946 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 18 Dec 2025 14:25:58 -0800
Subject: [PATCH 10/10] nits
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index a2bb755f40b16..8565a6b727fd1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1013,7 +1013,7 @@ def SMFMACOutTypes : AnyTypeOf<[
VectorOfLengthAndType<[4, 16], [I32]>
]>;
-def SparseMFMAIdxTypes : AnyTypeOf<[
+def SMFMACIdxTypes : AnyTypeOf<[
FixedVectorOfLengthAndType<[4], [I8]>,
FixedVectorOfLengthAndType<[2], [I16]>
]>;
@@ -1176,7 +1176,7 @@ def AMDGPU_SparseMFMAOp :
SMFMACSparseInTypes:$sourceA,
SMFMACDenseInTypes:$sourceB,
SMFMACOutTypes:$destC,
- SparseMFMAIdxTypes:$sparseIdx,
+ SMFMACIdxTypes:$sparseIdx,
DefaultValuedAttr<I32Attr, "0">:$cbsz,
DefaultValuedAttr<I32Attr, "0">:$abid)>,
Results<(outs SMFMACOutTypes: $destD)> {
More information about the Mlir-commits
mailing list