[Mlir-commits] [mlir] [mlir][amdgpu] implement amdgpu.sparse_mfma wrapper for smfmac instructions (PR #171968)

Eric Feng llvmlistbot at llvm.org
Thu Dec 11 22:35:41 PST 2025


https://github.com/efric created https://github.com/llvm/llvm-project/pull/171968

None

>From 73caf01662dba501bb1c887bdffa0eef6a3ba678 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:18:17 -0800
Subject: [PATCH 1/2] implement amdgpu wrapper for smfmac

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  81 ++++++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 176 +++++++++++++++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  56 ++++++
 .../AMDGPUToROCDL/sparse-mfma-gfx950.mlir     |  53 ++++++
 .../Conversion/AMDGPUToROCDL/sparse-mfma.mlir |  61 ++++++
 5 files changed, 422 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56160d3e8fe85..9b4947049c388 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -958,6 +958,27 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                                VectorOfLengthAndType<[4, 16, 32], [F32]>,
                                VectorOfLengthAndType<[4, 16, 32], [I32]>,
                                VectorOfLengthAndType<[4], [F64]>]>;
+
+// sparse_mfma (smfmac)
+def SMFMACSparseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 8], [F16]>,
+    VectorOfLengthAndType<[4, 8], [BF16]>,
+    VectorOfLengthAndType<[8, 16], [I8]>,
+    VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>
+]>;
+
+def SMFMACDenseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[8, 16], [F16]>,
+    VectorOfLengthAndType<[8, 16], [BF16]>,
+    VectorOfLengthAndType<[16, 32], [I8]>,
+    VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>
+]>;
+
+def SMFMACOutTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 16], [F32]>,
+    VectorOfLengthAndType<[4, 16], [I32]>
+]>;
+
 // scaled_mfma
 def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
                                     VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
@@ -1097,6 +1118,66 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_SparseMFMAOp :
+    AMDGPU_Op<"sparse_mfma", [AllTypesMatch<["destC", "destD"]>,
+                              Pure]>,
+    Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32, 64, 128]>]>:$k,
+                   SMFMACSparseInTypes:$sourceA,
+                   SMFMACDenseInTypes:$sourceB,
+                   SMFMACOutTypes:$destC,
+                   I32:$sparseIdx,
+                   DefaultValuedAttr<I32Attr, "0">:$cbsz,
+                   DefaultValuedAttr<I32Attr, "0">:$abid)>,
+    Results<(outs SMFMACOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA sparse mfma (smfmac) instructions";
+  let description = [{
+    The `amdgpu.sparse_mfma` op is an MLIR wrapper around intrinsics for various
+    `smfmac` instructions in the AMDGPU architecture, which perform matrix
+    multiply-accumulate operations using 2:4 structured sparsity on matrix A
+    with dense matrices B, C, and D.
+
+    On gfx942, smfmac intrinsics support:
+      - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources
+      - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources
+
+    On gfx950, smfmac intrinsics additionally support:
+      - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources
+      - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources
+
+    The `sparseIdx` parameter (i32) contains packed indices identifying the
+    positions of non-zero elements in the 2:4 sparse matrix A. For 16-bit data,
+    this uses four groups of 8-bit indices; for 8-bit data, two groups of
+    16-bit indices.
+
+    The `cbsz` and `abid` parameters are repurposed to select the index set.
+    If `cbsz == 0`, `abid` selects which index set to use (0-3 for 16-bit
+    sources, 0-1 for 8-bit sources). If `cbsz != 0`, the first set is used.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx)
+        : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+      %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx)
+        : vector<8xf16>, vector<16xf16>, vector<4xf32>
+
+      %2 = amdgpu.sparse_mfma 16x16x128 %matA * %matB + %matC sparse(%idx)
+        { cbsz = 0 : i32, abid = 1 : i32 }
+        : vector<16xi8>, vector<32xi8>, vector<4xi32>
+    ```
+  }];
+  let assemblyFormat = [{
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+    `sparse` `(` $sparseIdx `)`
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_GatherToLDSOp :
     AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
     Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4b1509392aa6f..855d0c9df4281 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -661,6 +661,30 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
 
+/// Converts smfmac operands to the types the selected ROCDL intrinsic expects.
+static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                            Location loc, Value input,
+                                            bool allowBf16 = true) {
+  Type inputType = input.getType();
+  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
+    // bf16 -> i16 unless the selected intrinsic takes bf16 directly.
+    if (vectorType.getElementType().isBF16() && !allowBf16)
+      return LLVM::BitcastOp::create(
+          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
+    // i8/fp8 vectors (fp8 arrives as i8 after type conversion) -> vector<Nxi32>
+    if (isa<IntegerType>(vectorType.getElementType()) &&
+        vectorType.getElementTypeBitWidth() <= 8) {
+      int64_t numWords = llvm::divideCeil(
+          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
+          32);
+      return LLVM::BitcastOp::create(
+          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+          input);
+    }
+  }
+  return input;
+}
+
 /// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
 /// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
 ///
@@ -1136,6 +1160,104 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
   return std::nullopt;
 }
 
+/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
+/// operation if one exists. This includes checking to ensure the intrinsic is
+/// supported on the architecture you are compiling for.
+static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
+                                                    bool isGfx950) {
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
+  uint32_t m = op.getM(), n = op.getN(), k = op.getK();
+  Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
+  Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
+  Type destElem = getElementTypeOrSelf(op.getDestC().getType());
+
+  if (m == 16 && n == 16 && k == 32) {
+    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
+    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
+  }
+
+  if (m == 16 && n == 16 && k == 64) {
+    if (isGfx950) {
+      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
+      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
+    }
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
+  }
+
+  if (m == 16 && n == 16 && k == 128 && isGfx950) {
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
+  }
+
+  if (m == 32 && n == 32 && k == 16) {
+    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
+    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
+  }
+
+  if (m == 32 && n == 32 && k == 32) {
+    if (isGfx950) {
+      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
+      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
+    }
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
+  }
+
+  if (m == 32 && n == 32 && k == 64 && isGfx950) {
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
+  }
+
+  return std::nullopt;
+}
+
 /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -1291,6 +1413,49 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
   }
 };
 
+struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
+  SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+  /// Lowers amdgpu.sparse_mfma to the matching rocdl.smfmac intrinsic.
+  LogicalResult
+  matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestC().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    // smfmac is supported on gfx942 and gfx950
+    if (chipset.majorVersion != 9 || chipset < kGfx942)
+      return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
+    bool isGfx950 = chipset >= kGfx950;
+    std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, isGfx950);
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching sparse MFMA on the given chipset");
+    // gfx942-era bf16 smfmac intrinsics take i16; gfx950-only ones take bf16.
+    bool allowBf16 = maybeIntrinsic->ends_with("16x16x64.bf16") ||
+                     maybeIntrinsic->ends_with("32x32x32.bf16");
+    Value a = convertSparseMFMAVectorOperand(rewriter, loc,
+                                             adaptor.getSourceA(), allowBf16);
+    Value b = convertSparseMFMAVectorOperand(rewriter, loc,
+                                             adaptor.getSourceB(), allowBf16);
+    Value c = adaptor.getDestC();
+
+    OperationState loweredOp(loc, maybeIntrinsic.value());
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands({a, b, c, adaptor.getSparseIdx(),
+                           createI32Constant(rewriter, loc, op.getCbsz()),
+                           createI32Constant(rewriter, loc, op.getAbid())});
+    Value lowered = rewriter.create(loweredOp)->getResult(0);
+    rewriter.replaceOp(op, lowered);
+    return success();
+  }
+};
+
 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -2797,11 +2962,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
       RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                           ROCDL::RawPtrBufferAtomicCmpSwap>,
       AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
-      SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
-      WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
-      ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
-      PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-      GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+      SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+      WMMAOpLowering, SparseMFMAOpLowering, ExtPackedFp8OpLowering,
+      ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+      PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+      PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+      TransposeLoadOpLowering, AMDGPUPermlaneLowering,
       AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
                                                                   chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..2cc1aaa8e3b2d 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -522,6 +522,68 @@ LogicalResult MFMAOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SparseMFMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SparseMFMAOp::verify() {
+  constexpr uint32_t waveSize = 64;
+
+  auto sparseType = cast<VectorType>(getSourceA().getType());
+  auto denseType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sparseElem = sparseType.getElementType();
+  Type denseElem = denseType.getElementType();
+  int64_t sparseLen = sparseType.getNumElements();
+  int64_t denseLen = denseType.getNumElements();
+  int64_t destLen = destType.getNumElements();
+
+  if (denseLen != 2 * sparseLen)
+    return emitOpError("expected dense source operand to have exactly double "
+                       "the number of elements of the sparse source operand");
+
+  // Check that source element types are compatible.
+  // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
+  // For other types, element types must match exactly.
+  bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
+  if (!bothFloat8 && sparseElem != denseElem)
+    return emitOpError(
+        "expected source operands to have the same element type");
+
+  // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
+  // When CBSZ != 0, the first index set is always used (ABID ignored).
+  bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
+  if (getCbsz() == 0 && is8BitSource) {
+    // 8-bit source: ABID[0] selects one of two 16-bit index sets.
+    if (getAbid() > 1)
+      return emitOpError(
+          "ABID must be 0 or 1 for 8-bit source data when CBSZ is 0");
+  }
+  // 16-bit source: ABID[1:0] selects one of four 8-bit index sets (0-3 all
+  // valid).
+
+  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
+  if (denseLen != expectedSourceElems)
+    return emitOpError("expected " + Twine(expectedSourceElems) +
+                       " source values for this operation but got " +
+                       Twine(denseLen));
+
+  int64_t expectedDestElems = (getM() * getN()) / waveSize;
+  if (destLen != expectedDestElems)
+    return emitOpError("expected " + Twine(expectedDestElems) +
+                       " result values for this operation but got " +
+                       Twine(destLen));
+
+  // i8 sources accumulate into i32; floating-point sources accumulate into
+  // f32 (the only other accumulator element type the op accepts).
+  if (sparseElem.isInteger(8) != destType.getElementType().isInteger(32))
+    return emitOpError("expected i32 accumulator for i8 sources and f32 "
+                       "accumulator for floating-point sources");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DPPOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
new file mode 100644
index 0000000000000..abe2565f7c41b
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s
+// Every shape/type combination below maps to an smfmac intrinsic gated on gfx950.
+// CHECK-LABEL: func @sparse_mfma_to_rocdl
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>,
+                                %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+                                %arg4 : vector<8xbf16>, %arg5 : vector<16xbf16>,
+                                %arg6 : vector<16xi8>, %arg7 : vector<32xi8>,
+                                %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+                                %arg10 : vector<16xf8E4M3FN>, %arg11 : vector<16xf8E5M2>,
+                                %arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>,
+                                %arg14 : i32) {
+  // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
+
+  func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
new file mode 100644
index 0000000000000..65a0cd3f1f87f
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 -cse | FileCheck %s
+// On gfx942, bf16 operands become i16 and 8-bit operands are packed into i32.
+// CHECK-LABEL: func @sparse_mfma_to_rocdl
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
+                                %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+                                %arg4 : vector<4xbf16>, %arg5 : vector<8xbf16>,
+                                %arg6 : vector<8xi8>, %arg7 : vector<16xi8>,
+                                %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+                                %arg10 : vector<8xf8E4M3FN>, %arg11 : vector<8xf8E5M2>,
+                                %arg12 : vector<16xf8E4M3FN>, %arg13 : vector<16xf8E5M2>,
+                                %arg14 : i32) {
+  // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
+  // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<16xf32>
+
+  func.return
+}

>From e05394559965b8fdb6001313d651242683d049ed Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:30:42 -0800
Subject: [PATCH 2/2] nits

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2cc1aaa8e3b2d..4231014c77982 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -554,14 +554,12 @@ LogicalResult SparseMFMAOp::verify() {
   // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
   // When CBSZ != 0, the first index set is always used (ABID ignored).
   bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
-  if (getCbsz() == 0 && is8BitSource) {
-    // 8-bit source: ABID[0] selects one of two 16-bit index sets.
-    if (getAbid() > 1)
-      return emitOpError(
-          "ABID must be 0 or 1 for 8-bit source data when CBSZ is 0");
-  }
-  // 16-bit source: ABID[1:0] selects one of four 8-bit index sets (0-3 all
-  // valid).
+  // ABID range: 0-1 (two 16-bit index sets) for 8-bit data, else 0-3 (four).
+  uint32_t maxAbid = is8BitSource ? 1 : 3;
+  if (getCbsz() == 0 && getAbid() > maxAbid)
+    return emitOpError(
+        is8BitSource ? "ABID must be 0 or 1 for 8-bit source data"
+                     : "ABID must be between 0 and 3 for 16-bit source data");
 
   int64_t expectedSourceElems = (getM() * getK()) / waveSize;
   if (denseLen != expectedSourceElems)



More information about the Mlir-commits mailing list