[Mlir-commits] [mlir] 24c7b4e - [mlir][amdgpu] implement amdgpu.sparse_mfma wrapper for smfmac instructions (#171968)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 18 18:16:18 PST 2025
Author: Eric Feng
Date: 2025-12-18T20:16:14-06:00
New Revision: 24c7b4ea4883dc37369c6c31565062817c78b328
URL: https://github.com/llvm/llvm-project/commit/24c7b4ea4883dc37369c6c31565062817c78b328
DIFF: https://github.com/llvm/llvm-project/commit/24c7b4ea4883dc37369c6c31565062817c78b328.diff
LOG: [mlir][amdgpu] implement amdgpu.sparse_mfma wrapper for smfmac instructions (#171968)
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
Added:
mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/test/Dialect/AMDGPU/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 8834f7c489496..7a8cd89f886a7 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1074,6 +1074,34 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
+
+// sparse_mfma (smfmac)
+// Per-lane operand types accepted by amdgpu.sparse_mfma.
+// The 2:4-sparse A operand (SMFMACSparseInTypes) is listed with half the
+// lengths of the dense B operand (SMFMACDenseInTypes) for each element type;
+// the op verifier additionally requires len(B) == 2 * len(A) for each use.
+def SMFMACSparseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 8], [F16]>,
+ VectorOfLengthAndType<[4, 8], [BF16]>,
+ VectorOfLengthAndType<[8, 16], [I8]>,
+ VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[8, 16], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+// Dense B operand: same element types as the sparse operand, twice the
+// number of elements.
+def SMFMACDenseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[8, 16], [F16]>,
+ VectorOfLengthAndType<[8, 16], [BF16]>,
+ VectorOfLengthAndType<[16, 32], [I8]>,
+ VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+// Accumulator / result (C and D). The intrinsic selection pairs f32 with
+// f16/bf16/8-bit-float sources and i32 with i8 sources.
+def SMFMACOutTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 16], [F32]>,
+ VectorOfLengthAndType<[4, 16], [I32]>
+]>;
+
+// Packed sparse indices. The verifier requires vector<4xi8> for 16-bit
+// sources and vector<2xi16> for 8-bit sources; lowering bitcasts either one
+// to a single i32.
+def SMFMACIdxTypes : AnyTypeOf<[
+ FixedVectorOfLengthAndType<[4], [I8]>,
+ FixedVectorOfLengthAndType<[2], [I16]>
+]>;
+
// scaled_mfma
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
@@ -1222,6 +1250,66 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
+
+def AMDGPU_SparseMFMAOp :
+    AMDGPU_Op<"sparse_mfma", [AllTypesMatch<["destC", "destD"]>,
+                              Pure]>,
+    Arguments<(ins
+      ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+      ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$n,
+      ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32, 64, 128]>]>:$k,
+      SMFMACSparseInTypes:$sourceA,
+      SMFMACDenseInTypes:$sourceB,
+      SMFMACOutTypes:$destC,
+      SMFMACIdxTypes:$sparseIdx,
+      DefaultValuedAttr<I32Attr, "0">:$cbsz,
+      DefaultValuedAttr<I32Attr, "0">:$abid)>,
+    Results<(outs SMFMACOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA sparse mfma (smfmac) instructions";
+  let description = [{
+    The `amdgpu.sparse_mfma` op is an MLIR wrapper around intrinsics for various
+    `smfmac` instructions in the AMDGPU architecture, which perform matrix
+    multiply-accumulate operations using 2:4 structured sparsity on matrix A
+    with dense matrices B, C, and D.
+
+    On gfx942, smfmac intrinsics support:
+    - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources
+    - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources
+
+    On gfx950, smfmac intrinsics additionally support:
+    - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources
+    - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources
+
+    The dense operand `sourceB` must hold exactly twice as many elements as
+    the 2:4-sparse operand `sourceA`. Both sources must have the same element
+    type, except that 8-bit float sources may mix fp8 and bf8 formats
+    (e.g. fp8 * bf8). The supported combinations accumulate f16, bf16, and
+    8-bit float sources into f32, and i8 sources into i32.
+
+    The `sparseIdx` parameter contains packed indices identifying the positions
+    of non-zero elements in the 2:4 sparse matrix A. For 16-bit source data,
+    use `vector<4xi8>` (four 8-bit indices). For 8-bit source data, use
+    `vector<2xi16>` (two 16-bit indices).
+
+    The `cbsz` and `abid` parameters are repurposed to select the index set.
+    If `cbsz == 0`, then `abid[1:0]` selects which index set to use; `abid`
+    must be in [0, 3] for 16-bit sources and in [0, 1] for 8-bit sources.
+    If `cbsz != 0`, the first index set is always selected and `abid` is
+    ignored.
+
+    Example:
+    ```mlir
+    %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+      : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+    %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
+      : vector<8xi8>, vector<16xi8>, vector<4xi32>
+
+    %2 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
+      { cbsz = 0 : i32, abid = 1 : i32 }
+      : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+    ```
+  }];
+  let assemblyFormat = [{
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+    `sparse` `(` $sparseIdx `:` type($sparseIdx) `)`
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
def AMDGPU_GatherToLDSOp :
AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 90009c9722fe3..9b31fc69bed35 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -661,6 +661,27 @@ static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
return input;
}
+/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
+///
+/// When `allowBf16` is false, bf16 vectors are reinterpreted as i16 vectors of
+/// the same length. Vectors whose elements are integers of at most 8 bits
+/// (i8 sources, and 8-bit floats after LLVM type conversion) are repacked as
+/// a vector of i32 words. Any other operand is returned unchanged.
+static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                            Location loc, Value input,
+                                            bool allowBf16 = true) {
+  auto vecTy = cast<VectorType>(input.getType());
+  Type elemTy = vecTy.getElementType();
+
+  // Intrinsics that do not accept bf16 take the same bits as i16.
+  if (!allowBf16 && elemTy.isBF16()) {
+    Type i16VecTy = vecTy.clone(rewriter.getI16Type());
+    return LLVM::BitcastOp::create(rewriter, loc, i16VecTy, input);
+  }
+
+  // Only narrow (<= 8-bit) integer vectors need repacking into 32-bit words.
+  bool isNarrowInt =
+      isa<IntegerType>(elemTy) && vecTy.getElementTypeBitWidth() <= 8;
+  if (!isNarrowInt)
+    return input;
+
+  int64_t totalBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
+  int64_t wordCount = llvm::divideCeil(totalBits, 32);
+  auto packedTy = VectorType::get(wordCount, rewriter.getI32Type());
+  return LLVM::BitcastOp::create(rewriter, loc, packedTy, input);
+}
+
/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
@@ -1171,6 +1192,105 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
return std::nullopt;
}
+/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
+/// operation if one exists. This includes checking to ensure the intrinsic is
+/// supported on the architecture you are compiling for.
+///
+/// The lookup matches the exact (M, N, K) triple together with the element
+/// types of `sourceA`, `sourceB`, and `destC`. Shapes and types marked as
+/// gfx950-only below are returned only when `chipset >= kGfx950`. Which 8-bit
+/// float types count as "fp8" or "bf8" is delegated to
+/// `typeIsExpectedFp8ForChipset` / `typeIsExpectedBf8ForChipset`; the
+/// lowering tests pair the FNUZ variants with gfx942 and f8E4M3FN / f8E5M2
+/// with gfx950. Returns `std::nullopt` when no intrinsic matches. Only the
+/// intrinsic name is chosen here; operand type conversion is left to the
+/// caller.
+static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
+ Chipset chipset) {
+ bool isGfx950 = chipset >= kGfx950;
+ // Chipset-dependent classification of 8-bit float element types.
+ auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
+ auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
+
+ uint32_t m = op.getM(), n = op.getN(), k = op.getK();
+ Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
+ Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
+ Type destElem = getElementTypeOrSelf(op.getDestC().getType());
+
+ // 16x16x32: f16 / bf16 sources; not gated on gfx950.
+ if (m == 16 && n == 16 && k == 32) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
+ }
+
+ // 16x16x64: f16 / bf16 only on gfx950; i8 and fp8/bf8 combinations are not
+ // gated on gfx950.
+ if (m == 16 && n == 16 && k == 64) {
+ if (isGfx950) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
+ }
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
+ }
+
+ // 16x16x128: gfx950 only; i8 and fp8/bf8 combinations.
+ if (m == 16 && n == 16 && k == 128 && isGfx950) {
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
+ }
+
+ // 32x32x16: f16 / bf16 sources; not gated on gfx950.
+ if (m == 32 && n == 32 && k == 16) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
+ }
+
+ // 32x32x32: f16 / bf16 only on gfx950; i8 and fp8/bf8 combinations are not
+ // gated on gfx950.
+ if (m == 32 && n == 32 && k == 32) {
+ if (isGfx950) {
+ if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
+ if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
+ }
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
+ }
+
+ // 32x32x64: gfx950 only; i8 and fp8/bf8 combinations.
+ if (m == 32 && n == 32 && k == 64 && isGfx950) {
+ if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+ destElem.isInteger(32))
+ return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
+ if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
+ if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
+ if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
+ if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
+ return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
+ }
+
+ // No (shape, type, chipset) combination matched.
+ return std::nullopt;
+}
+
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -1326,6 +1446,52 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
}
};
+struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
+  SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  /// Lowers `amdgpu.sparse_mfma` to the matching `rocdl.smfmac.*` operation.
+  ///
+  /// The intrinsic is selected before any IR is created, so a failed match
+  /// leaves the IR untouched. Operands are then converted to the intrinsic's
+  /// expected types:
+  /// - i8 and 8-bit float vectors are packed into i32 words.
+  /// - bf16 vectors are passed as i16 vectors, except for the gfx950-only bf16
+  ///   intrinsics (16x16x64, 32x32x32), which take bf16 directly.
+  /// The sparse indices are bitcast to a single i32, and `cbsz`/`abid` are
+  /// materialized as i32 constants.
+  LogicalResult
+  matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestC().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    // smfmac is supported on gfx942 and gfx950.
+    if (chipset.majorVersion != 9 || chipset < kGfx942)
+      return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
+
+    // Select the intrinsic first so that no IR is created on the failure path.
+    std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching sparse MFMA on the given chipset");
+    StringRef intrinsicName = *maybeIntrinsic;
+
+    // Whether bf16 may be passed through depends on the intrinsic, not only on
+    // the chipset. The gfx942-era 16x16x32 / 32x32x16 bf16 intrinsics take i16
+    // vectors (see the gfx942 lowering test), and they keep that signature
+    // when targeting gfx950. Only the bf16 intrinsics added with gfx950 take
+    // bf16.
+    bool allowBf16 =
+        intrinsicName == ROCDL::smfmac_f32_16x16x64_bf16::getOperationName() ||
+        intrinsicName == ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
+
+    Value a = convertSparseMFMAVectorOperand(rewriter, loc,
+                                             adaptor.getSourceA(), allowBf16);
+    Value b = convertSparseMFMAVectorOperand(rewriter, loc,
+                                             adaptor.getSourceB(), allowBf16);
+    Value c = adaptor.getDestC();
+
+    // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
+    Value sparseIdx = LLVM::BitcastOp::create(
+        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
+
+    OperationState loweredOp(loc, intrinsicName);
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands({a, b, c, sparseIdx,
+                           createI32Constant(rewriter, loc, op.getCbsz()),
+                           createI32Constant(rewriter, loc, op.getAbid())});
+    Value lowered = rewriter.create(loweredOp)->getResult(0);
+    rewriter.replaceOp(op, lowered);
+    return success();
+  }
+};
+
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -3376,12 +3542,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
- ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
- PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering, AMDGPUPermlaneLowering,
- AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+ SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
+ ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
+ ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GatherToLDSOpLowering, TransposeLoadOpLowering,
+ AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index bef0328c7c73e..e77d131509add 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -632,6 +632,78 @@ LogicalResult MFMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// SparseMFMAOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the `amdgpu.sparse_mfma` constraints that the ODS operand type
+/// constraints cannot express on their own:
+///  - M must equal N;
+///  - the dense `sourceB` must have twice as many elements as `sourceA`;
+///  - source element types must match, except for fp8/bf8 mixes;
+///  - when `cbsz == 0`, `abid` must select an existing index set;
+///  - `sparseIdx` must be vector<2xi16> for 8-bit sources and vector<4xi8>
+///    for 16-bit sources;
+///  - the per-lane source and result lengths must match M, N, K for a
+///    64-lane wave.
+LogicalResult SparseMFMAOp::verify() {
+  constexpr uint32_t waveSize = 64;
+
+  // Every smfmac shape is square in M x N. The source-length check below
+  // computes the dense (K x N) operand size as M * K, which is only correct
+  // because M == N, so reject non-square shapes explicitly.
+  if (getM() != getN())
+    return emitOpError("expected M and N to be equal, but got M = ")
+           << getM() << " and N = " << getN();
+
+  auto sparseType = cast<VectorType>(getSourceA().getType());
+  auto denseType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sparseElem = sparseType.getElementType();
+  Type denseElem = denseType.getElementType();
+  int64_t sparseLen = sparseType.getNumElements();
+  int64_t denseLen = denseType.getNumElements();
+  int64_t destLen = destType.getNumElements();
+
+  if (denseLen != 2 * sparseLen)
+    return emitOpError("expected dense source operand to have exactly double "
+                       "the number of elements of the sparse source operand");
+
+  // Check that source element types are compatible.
+  // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
+  // For other types, element types must match exactly.
+  // NOTE(review): this also admits mixing the FNUZ and non-FNUZ 8-bit float
+  // families; presumably no intrinsic accepts that mix, so it would only be
+  // rejected during lowering.
+  bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
+  if (!bothFloat8 && sparseElem != denseElem)
+    return emitOpError(
+        "expected source operands to have the same element type");
+
+  // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
+  // When CBSZ != 0, the first index set is always used (ABID ignored).
+  bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
+  // 8-bit source: ABID selects one of two 16-bit index sets.
+  if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
+    return emitOpError("ABID must be 0 or 1 for 8-bit source data");
+  // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid).
+  if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
+    return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
+
+  // Validate sparseIdx type matches source element type.
+  auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
+  if (is8BitSource) {
+    // 8-bit source data requires vector<2xi16> sparse indices.
+    if (sparseIdxType.getNumElements() != 2 ||
+        !sparseIdxType.getElementType().isInteger(16))
+      return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
+                         "source data, but got ")
+             << getSparseIdx().getType();
+  } else {
+    // 16-bit source data requires vector<4xi8> sparse indices.
+    if (sparseIdxType.getNumElements() != 4 ||
+        !sparseIdxType.getElementType().isInteger(8))
+      return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
+                         "source data, but got ")
+             << getSparseIdx().getType();
+  }
+
+  // Per-lane element counts for a 64-lane wave.
+  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
+  if (denseLen != expectedSourceElems)
+    return emitOpError("expected " + Twine(expectedSourceElems) +
+                       " source values for this operation but got " +
+                       Twine(denseLen));
+
+  int64_t expectedDestElems = (getM() * getN()) / waveSize;
+  if (destLen != expectedDestElems)
+    return emitOpError("expected " + Twine(expectedDestElems) +
+                       " result values for this operation but got " +
+                       Twine(destLen));
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
new file mode 100644
index 0000000000000..266e0e7e15595
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s
+
+// Lowering of amdgpu.sparse_mfma on gfx950. Covers the shapes the op
+// documentation lists as gfx950 additions: 16x16x64 and 32x32x32 with
+// f16/bf16 sources, and 16x16x128 and 32x32x64 with i8 and
+// f8E4M3FN/f8E5M2 sources (including mixed fp8 * bf8).
+// On this chipset bf16 operands reach the intrinsic as bf16, 8-bit operands
+// are bitcast to i32 vectors, and the sparse indices are bitcast to an i32.
+// NOTE(review): the 16x16x32 / 32x32x16 f16 and bf16 shapes are not exercised
+// on gfx950 in this file.
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>,
+ %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+ %arg4 : vector<8xbf16>, %arg5 : vector<16xbf16>,
+ %arg6 : vector<16xi8>, %arg7 : vector<32xi8>,
+ %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+ %arg10 : vector<16xf8E4M3FN>, %arg11 : vector<16xf8E5M2>,
+ %arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>,
+ %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
+ // 16x16 shapes.
+ // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
+ // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
+
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
+ // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
+ // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
+
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
+
+ // 32x32 shapes.
+ // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
+
+ func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
new file mode 100644
index 0000000000000..b2c91c3d9bed1
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 -cse | FileCheck %s
+
+// Lowering of amdgpu.sparse_mfma on gfx942. Covers 16x16x32 and 32x32x16
+// with f16/bf16 sources, and 16x16x64 and 32x32x32 with i8 and
+// f8E4M3FNUZ/f8E5M2FNUZ sources (including mixed fp8 * bf8).
+// On this chipset bf16 operands are bitcast to i16 vectors, 8-bit operands
+// are bitcast to i32 vectors, and the sparse indices are bitcast to an i32.
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
+ %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+ %arg4 : vector<4xbf16>, %arg5 : vector<8xbf16>,
+ %arg6 : vector<8xi8>, %arg7 : vector<16xi8>,
+ %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+ %arg10 : vector<8xf8E4M3FNUZ>, %arg11 : vector<8xf8E5M2FNUZ>,
+ %arg12 : vector<16xf8E4M3FNUZ>, %arg13 : vector<16xf8E5M2FNUZ>,
+ %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
+ // 16-bit sources (f16 / bf16).
+ // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
+ // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+ // CHECK: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+ // CHECK: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
+ // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
+
+ // 8-bit sources (i8 / fp8 / bf8).
+ // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
+ // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+
+ // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+
+ // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+ // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
+
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 9ece57e9ec6a3..1299f3b14b14f 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -452,3 +452,67 @@ func.func @make_gather_dma_descriptor_invalid_index_types(%base: !amdgpu.tdm_gat
amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2] : !amdgpu.tdm_gather_base<i32, i16>, vector<8xi32> -> !amdgpu.tdm_descriptor
func.return
}
+
+// -----
+
+// A length-4 sourceB is rejected by the ODS operand type constraint (dense f16
+// operands must have 8 or 16 elements) before the op verifier runs.
+func.func @sparse_mfma_dense_not_double_sparse(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op operand #1 must be vector of 16-bit float values of length 8/16 or vector of bfloat16 type values of length 8/16 or vector of 8-bit signless integer values of length 16/32 or vector of f8E4M3FN type or f8E5M2 type values of length 16/32 or vector of f8E4M3FNUZ type or f8E5M2FNUZ type values of length 16/32, but got 'vector<4xf16>'}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<4xf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+// Both sources satisfy the ODS length constraints, but sourceB does not hold
+// twice as many elements as sourceA, so the op verifier rejects it.
+func.func @sparse_mfma_dense_len_not_double(%a: vector<8xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected dense source operand to have exactly double the number of elements of the sparse source operand}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<8xf16>, vector<8xf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+// f16 and bf16 sources may not be mixed.
+func.func @sparse_mfma_mismatched_source_types(%a: vector<4xf16>, %b: vector<8xbf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected source operands to have the same element type}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xbf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+// With cbsz = 0, 8-bit sources only have index sets 0 and 1.
+func.func @sparse_mfma_abid_invalid_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<2xi16>) -> vector<4xi32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op ABID must be 0 or 1 for 8-bit source data}}
+  %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<2xi16>) { abid = 2 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+  func.return %d : vector<4xi32>
+}
+
+// -----
+
+// With cbsz = 0, 16-bit sources only have index sets 0 through 3.
+func.func @sparse_mfma_abid_invalid_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op ABID must be between 0 and 3 for 16-bit source data}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) { abid = 4 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+// 8-bit sources require vector<2xi16> sparse indices.
+func.func @sparse_mfma_wrong_idx_type_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<4xi8>) -> vector<4xi32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected vector<2xi16> sparse indices for 8-bit source data, but got 'vector<4xi8>'}}
+  %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<8xi8>, vector<16xi8>, vector<4xi32>
+  func.return %d : vector<4xi32>
+}
+
+// -----
+
+// 16-bit sources require vector<4xi8> sparse indices.
+func.func @sparse_mfma_wrong_idx_type_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<2xi16>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected vector<4xi8> sparse indices for 16-bit source data, but got 'vector<2xi16>'}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<2xi16>) : vector<4xf16>, vector<8xf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+// 32x32x32 needs 32 * 32 / 64 = 16 dense source values per lane.
+func.func @sparse_mfma_wrong_source_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected 16 source values for this operation but got 8}}
+  %d = amdgpu.sparse_mfma 32x32x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32>
+  func.return %d : vector<16xf32>
+}
+
+// -----
+
+// 16x16 needs 16 * 16 / 64 = 4 result values per lane.
+func.func @sparse_mfma_wrong_dest_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected 4 result values for this operation but got 16}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32>
+  func.return %d : vector<16xf32>
+}
More information about the Mlir-commits
mailing list