[Mlir-commits] [mlir] [mlir][amdgpu] implement amdgpu.sparse_mfma wrapper for smfmac instructions (PR #171968)

Eric Feng llvmlistbot at llvm.org
Thu Dec 18 14:35:13 PST 2025


https://github.com/efric updated https://github.com/llvm/llvm-project/pull/171968

>From 73caf01662dba501bb1c887bdffa0eef6a3ba678 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:18:17 -0800
Subject: [PATCH 01/11] implement amdgpu wrapper for smfmac

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  81 ++++++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 176 +++++++++++++++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  56 ++++++
 .../AMDGPUToROCDL/sparse-mfma-gfx950.mlir     |  53 ++++++
 .../Conversion/AMDGPUToROCDL/sparse-mfma.mlir |  61 ++++++
 5 files changed, 422 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56160d3e8fe85..9b4947049c388 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -958,6 +958,27 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
+
+// sparse_mfma (smfmac)
+def SMFMACSparseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 8], [F16]>,
+    VectorOfLengthAndType<[4, 8], [BF16]>,
+    VectorOfLengthAndType<[8, 16], [I8]>,
+    VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>
+]>;
+
+def SMFMACDenseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[8, 16], [F16]>,
+    VectorOfLengthAndType<[8, 16], [BF16]>,
+    VectorOfLengthAndType<[16, 32], [I8]>,
+    VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>
+]>;
+
+def SMFMACOutTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 16], [F32]>,
+    VectorOfLengthAndType<[4, 16], [I32]>
+]>;
+
 // scaled_mfma
 def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
@@ -1097,6 +1118,66 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_SparseMFMAOp :
+    AMDGPU_Op<"sparse_mfma", [AllTypesMatch<["destC", "destD"]>,
+                              Pure]>,
+    Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32, 64, 128]>]>:$k,
+                   SMFMACSparseInTypes:$sourceA,
+                   SMFMACDenseInTypes:$sourceB,
+                   SMFMACOutTypes:$destC,
+                   I32:$sparseIdx,
+                   DefaultValuedAttr<I32Attr, "0">:$cbsz,
+                   DefaultValuedAttr<I32Attr, "0">:$abid)>,
+    Results<(outs SMFMACOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA sparse mfma (smfmac) instructions";
+  let description = [{
+    The `amdgpu.sparse_mfma` op is an MLIR wrapper around intrinsics for various
+    `smfmac` instructions in the AMDGPU architecture, which perform matrix
+    multiply-accumulate operations using 2:4 structured sparsity on matrix A
+    with dense matrices B, C, and D.
+
+    On gfx940, smfmac intrinsics support:
+      - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources
+      - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources
+
+    On gfx950, smfmac intrinsics additionally support:
+      - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources
+      - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources
+
+    The `sparseIdx` parameter (i32) contains packed indices identifying the
+    positions of non-zero elements in the 2:4 sparse matrix A. For 16-bit data,
+    it holds four 8-bit index sets; for 8-bit data, it holds two 16-bit index
+    sets.
+
+    The `cbsz` and `abid` parameters are repurposed to select the index set.
+    If `cbsz == 0`, then `abid[1:0]` selects which index set to use.
+    If `cbsz != 0`, the first index set is used and `abid` is ignored.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx)
+        : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+      %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx)
+        : vector<8xf16>, vector<16xf16>, vector<4xf32>
+
+      %2 = amdgpu.sparse_mfma 16x16x128 %matA * %matB + %matC sparse(%idx)
+        { cbsz = 0 : i32, abid = 1 : i32 }
+        : vector<16xi8>, vector<32xi8>, vector<4xi32>
+    ```
+  }];
+  let assemblyFormat = [{
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+    `sparse` `(` $sparseIdx `)`
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_GatherToLDSOp :
     AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
     Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4b1509392aa6f..855d0c9df4281 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -661,6 +661,30 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
 
+/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
+static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                            Location loc, Value input,
+                                            bool allowBf16 = true) {
+  Type inputType = input.getType();
+  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
+    // bf16 -> i16 when not allowed (pre-gfx950)
+    if (vectorType.getElementType().isBF16() && !allowBf16)
+      return LLVM::BitcastOp::create(
+          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
+    // i8/fp8 vectors -> vector<Nxi32>
+    if (isa<IntegerType>(vectorType.getElementType()) &&
+        vectorType.getElementTypeBitWidth() <= 8) {
+      int64_t numWords = llvm::divideCeil(
+          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
+          32);
+      return LLVM::BitcastOp::create(
+          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+          input);
+    }
+  }
+  return input;
+}
+
 /// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
 /// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
 ///
@@ -1136,6 +1160,104 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
   return std::nullopt;
 }
 
+/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
+/// operation if one exists. This includes checking to ensure the intrinsic is
+/// supported on the architecture you are compiling for.
+static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
+                                                    bool isGfx950) {
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
+  uint32_t m = op.getM(), n = op.getN(), k = op.getK();
+  Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
+  Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
+  Type destElem = getElementTypeOrSelf(op.getDestC().getType());
+
+  if (m == 16 && n == 16 && k == 32) {
+    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
+    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
+  }
+
+  if (m == 16 && n == 16 && k == 64) {
+    if (isGfx950) {
+      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
+      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
+    }
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
+  }
+
+  if (m == 16 && n == 16 && k == 128 && isGfx950) {
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
+  }
+
+  if (m == 32 && n == 32 && k == 16) {
+    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
+    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
+  }
+
+  if (m == 32 && n == 32 && k == 32) {
+    if (isGfx950) {
+      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
+      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
+        return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
+    }
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
+  }
+
+  if (m == 32 && n == 32 && k == 64 && isGfx950) {
+    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
+        destElem.isInteger(32))
+      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
+    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
+    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
+  }
+
+  return std::nullopt;
+}
+
 /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -1291,6 +1413,49 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
   }
 };
 
+struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
+  SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestC().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    // smfmac is supported on gfx942 and gfx950
+    if (chipset.majorVersion != 9 || chipset < kGfx942)
+      return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
+    bool isGfx950 = chipset >= kGfx950;
+
+    Value a = convertSparseMFMAVectorOperand(rewriter, loc,
+                                             adaptor.getSourceA(), isGfx950);
+    Value b = convertSparseMFMAVectorOperand(rewriter, loc,
+                                             adaptor.getSourceB(), isGfx950);
+    Value c = adaptor.getDestC();
+
+    std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, isGfx950);
+
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching sparse MFMA on the given chipset");
+
+    OperationState loweredOp(loc, maybeIntrinsic.value());
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands({a, b, c, adaptor.getSparseIdx(),
+                           createI32Constant(rewriter, loc, op.getCbsz()),
+                           createI32Constant(rewriter, loc, op.getAbid())});
+    Value lowered = rewriter.create(loweredOp)->getResult(0);
+    rewriter.replaceOp(op, lowered);
+    return success();
+  }
+};
+
 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -2797,11 +2962,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
       RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                           ROCDL::RawPtrBufferAtomicCmpSwap>,
       AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
-      SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
-      WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
-      ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
-      PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-      GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+      SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, 
+      WMMAOpLowering, SparseMFMAOpLowering, ExtPackedFp8OpLowering,
+      ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+      PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+      PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+      TransposeLoadOpLowering, AMDGPUPermlaneLowering,
       AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
                                                                   chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..2cc1aaa8e3b2d 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -522,6 +522,62 @@ LogicalResult MFMAOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SparseMFMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SparseMFMAOp::verify() {
+  constexpr uint32_t waveSize = 64;
+
+  auto sparseType = cast<VectorType>(getSourceA().getType());
+  auto denseType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sparseElem = sparseType.getElementType();
+  Type denseElem = denseType.getElementType();
+  int64_t sparseLen = sparseType.getNumElements();
+  int64_t denseLen = denseType.getNumElements();
+  int64_t destLen = destType.getNumElements();
+
+  if (denseLen != 2 * sparseLen)
+    return emitOpError("expected dense source operand to have exactly double "
+                       "the number of elements of the sparse source operand");
+
+  // Check that source element types are compatible.
+  // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8).
+  // For other types, element types must match exactly.
+  bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
+  if (!bothFloat8 && sparseElem != denseElem)
+    return emitOpError(
+        "expected source operands to have the same element type");
+
+  // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
+  // When CBSZ != 0, the first index set is always used (ABID ignored).
+  bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
+  if (getCbsz() == 0 && is8BitSource) {
+    // 8-bit source: ABID[0] selects one of two 16-bit index sets.
+    if (getAbid() > 1)
+      return emitOpError(
+          "ABID must be 0 or 1 for 8-bit source data when CBSZ is 0");
+  }
+  // 16-bit source: ABID[1:0] selects one of four 8-bit index sets (0-3 all
+  // valid).
+
+  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
+  if (denseLen != expectedSourceElems)
+    return emitOpError("expected " + Twine(expectedSourceElems) +
+                       " source values for this operation but got " +
+                       Twine(denseLen));
+
+  int64_t expectedDestElems = (getM() * getN()) / waveSize;
+  if (destLen != expectedDestElems)
+    return emitOpError("expected " + Twine(expectedDestElems) +
+                       " result values for this operation but got " +
+                       Twine(destLen));
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DPPOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
new file mode 100644
index 0000000000000..abe2565f7c41b
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>,
+                                %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+                                %arg4 : vector<8xbf16>, %arg5 : vector<16xbf16>,
+                                %arg6 : vector<16xi8>, %arg7 : vector<32xi8>,
+                                %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+                                %arg10 : vector<16xf8E4M3FN>, %arg11 : vector<16xf8E5M2>,
+                                %arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>,
+                                %arg14 : i32) {
+  // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
+
+  func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
new file mode 100644
index 0000000000000..65a0cd3f1f87f
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 -cse | FileCheck %s
+func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
+                                %arg2 : vector<4xf32>, %arg3 : vector<16xf32>,
+                                %arg4 : vector<4xbf16>, %arg5 : vector<8xbf16>,
+                                %arg6 : vector<8xi8>, %arg7 : vector<16xi8>,
+                                %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
+                                %arg10 : vector<8xf8E4M3FN>, %arg11 : vector<8xf8E5M2>,
+                                %arg12 : vector<16xf8E4M3FN>, %arg13 : vector<16xf8E5M2>,
+                                %arg14 : i32) {
+  // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
+  // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
+
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<16xf32>
+
+  func.return
+}

>From e05394559965b8fdb6001313d651242683d049ed Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:30:42 -0800
Subject: [PATCH 02/11] nits

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2cc1aaa8e3b2d..4231014c77982 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -554,14 +554,12 @@ LogicalResult SparseMFMAOp::verify() {
   // When CBSZ == 0, ABID selects the index set within the sparse index VGPR.
   // When CBSZ != 0, the first index set is always used (ABID ignored).
   bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8);
-  if (getCbsz() == 0 && is8BitSource) {
-    // 8-bit source: ABID[0] selects one of two 16-bit index sets.
-    if (getAbid() > 1)
-      return emitOpError(
-          "ABID must be 0 or 1 for 8-bit source data when CBSZ is 0");
-  }
-  // 16-bit source: ABID[1:0] selects one of four 8-bit index sets (0-3 all
-  // valid).
+  // 8-bit source: ABID selects one of two 16-bit index sets.
+  if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
+    return emitOpError("ABID must be 0 or 1 for 8-bit source data");
+  // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid).
+  if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
+    return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
 
   int64_t expectedSourceElems = (getM() * getK()) / waveSize;
   if (denseLen != expectedSourceElems)

>From f3d945681062873b967dc4c5711b495d85816d67 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:36:58 -0800
Subject: [PATCH 03/11] nit

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 9b4947049c388..b23343f5cdd1e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1139,7 +1139,7 @@ def AMDGPU_SparseMFMAOp :
     multiply-accumulate operations using 2:4 structured sparsity on matrix A
     with dense matrices B, C, and D.
 
-    On gfx940, smfmac intrinsics support:
+    On gfx942, smfmac intrinsics support:
       - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources
       - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources
 

>From 248ad159f19e988c2a3eaebc9672a1ce97b9c307 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:39:01 -0800
Subject: [PATCH 04/11] format thing

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 46 +++++++++----------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 855d0c9df4281..65a68d417b663 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2947,28 +2947,28 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
   populateAMDGPUTypeAndAttributeConversions(converter);
-  patterns.add<
-      FatRawBufferCastLowering,
-      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
-      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
-      RawBufferOpLowering<RawBufferAtomicFaddOp,
-                          ROCDL::RawPtrBufferAtomicFaddOp>,
-      RawBufferOpLowering<RawBufferAtomicFmaxOp,
-                          ROCDL::RawPtrBufferAtomicFmaxOp>,
-      RawBufferOpLowering<RawBufferAtomicSmaxOp,
-                          ROCDL::RawPtrBufferAtomicSmaxOp>,
-      RawBufferOpLowering<RawBufferAtomicUminOp,
-                          ROCDL::RawPtrBufferAtomicUminOp>,
-      RawBufferOpLowering<RawBufferAtomicCmpswapOp,
-                          ROCDL::RawPtrBufferAtomicCmpSwap>,
-      AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
-      SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, 
-      WMMAOpLowering, SparseMFMAOpLowering, ExtPackedFp8OpLowering,
-      ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
-      PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
-      PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
-      TransposeLoadOpLowering, AMDGPUPermlaneLowering,
-      AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
-                                                                  chipset);
+  patterns
+      .add<FatRawBufferCastLowering,
+           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
+           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
+           RawBufferOpLowering<RawBufferAtomicFaddOp,
+                               ROCDL::RawPtrBufferAtomicFaddOp>,
+           RawBufferOpLowering<RawBufferAtomicFmaxOp,
+                               ROCDL::RawPtrBufferAtomicFmaxOp>,
+           RawBufferOpLowering<RawBufferAtomicSmaxOp,
+                               ROCDL::RawPtrBufferAtomicSmaxOp>,
+           RawBufferOpLowering<RawBufferAtomicUminOp,
+                               ROCDL::RawPtrBufferAtomicUminOp>,
+           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
+                               ROCDL::RawPtrBufferAtomicCmpSwap>,
+           AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
+           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+           WMMAOpLowering, SparseMFMAOpLowering, ExtPackedFp8OpLowering,
+           ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+           AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(
+          converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }

>From b18719d7811e411b37f6dddc7054b71950ab8518 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 11 Dec 2025 22:59:09 -0800
Subject: [PATCH 05/11] nit

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 65a68d417b663..c0f089b9fdb67 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -667,11 +667,11 @@ static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                             bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
-    // bf16 -> i16 when not allowed (pre-gfx950)
+    // bf16 -> i16 when not allowed (pre-gfx950).
     if (vectorType.getElementType().isBF16() && !allowBf16)
       return LLVM::BitcastOp::create(
           rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
-    // i8/fp8 vectors -> vector<Nxi32>
+    // i8/fp8 vectors -> vector<Nxi32>.
     if (isa<IntegerType>(vectorType.getElementType()) &&
         vectorType.getElementTypeBitWidth() <= 8) {
       int64_t numWords = llvm::divideCeil(

>From a68c02bd3684a3456605281b013f692b89375f48 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 12 Dec 2025 09:48:01 -0800
Subject: [PATCH 06/11] nits

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  6 +-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 69 +++++++++----------
 .../Conversion/AMDGPUToROCDL/sparse-mfma.mlir | 20 +++---
 3 files changed, 47 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b23343f5cdd1e..7d8c41b4c95cb 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -964,14 +964,16 @@ def SMFMACSparseInTypes : AnyTypeOf<[
     VectorOfLengthAndType<[4, 8], [F16]>,
     VectorOfLengthAndType<[4, 8], [BF16]>,
     VectorOfLengthAndType<[8, 16], [I8]>,
-    VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>
+    VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>,
+    VectorOfLengthAndType<[8, 16], [F8E4M3FNUZ, F8E5M2FNUZ]>
 ]>;
 
 def SMFMACDenseInTypes : AnyTypeOf<[
     VectorOfLengthAndType<[8, 16], [F16]>,
     VectorOfLengthAndType<[8, 16], [BF16]>,
     VectorOfLengthAndType<[16, 32], [I8]>,
-    VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>
+    VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>,
+    VectorOfLengthAndType<[16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
 ]>;
 
 def SMFMACOutTypes : AnyTypeOf<[
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c0f089b9fdb67..51388cbfa458e 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -666,21 +666,18 @@ static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                             Location loc, Value input,
                                             bool allowBf16 = true) {
   Type inputType = input.getType();
-  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
-    // bf16 -> i16 when not allowed (pre-gfx950).
-    if (vectorType.getElementType().isBF16() && !allowBf16)
-      return LLVM::BitcastOp::create(
-          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
-    // i8/fp8 vectors -> vector<Nxi32>.
-    if (isa<IntegerType>(vectorType.getElementType()) &&
-        vectorType.getElementTypeBitWidth() <= 8) {
-      int64_t numWords = llvm::divideCeil(
-          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
-          32);
-      return LLVM::BitcastOp::create(
-          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
-          input);
-    }
+  auto vectorType = cast<VectorType>(inputType);
+  // bf16 -> i16 when not allowed (pre-gfx950).
+  if (vectorType.getElementType().isBF16() && !allowBf16)
+    return LLVM::BitcastOp::create(
+        rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
+  // i8/fp8 vectors -> vector<Nxi32>.
+  if (isa<IntegerType>(vectorType.getElementType()) &&
+      vectorType.getElementTypeBitWidth() <= 8) {
+    int64_t numWords = llvm::divideCeil(
+        vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
+    return LLVM::BitcastOp::create(
+        rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
   }
   return input;
 }
@@ -1164,9 +1161,10 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
 /// operation if one exists. This includes checking to ensure the intrinsic is
 /// supported on the architecture you are compiling for.
 static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
-                                                    bool isGfx950) {
-  using fp8 = Float8E4M3FNType;
-  using bf8 = Float8E5M2Type;
+                                                    Chipset chipset) {
+  bool isGfx950 = chipset >= kGfx950;
+  auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
+  auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
 
   uint32_t m = op.getM(), n = op.getN(), k = op.getK();
   Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
@@ -1190,13 +1188,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
     if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
         destElem.isInteger(32))
       return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
   }
 
@@ -1204,13 +1202,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
     if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
         destElem.isInteger(32))
       return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
   }
 
@@ -1231,13 +1229,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
     if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
         destElem.isInteger(32))
       return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
   }
 
@@ -1245,13 +1243,13 @@ static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
     if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
         destElem.isInteger(32))
       return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
-    if (isa<fp8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<fp8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
-    if (isa<bf8>(sourceAElem) && isa<bf8>(sourceBElem) && destElem.isF32())
+    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
       return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
   }
 
@@ -1439,8 +1437,7 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
                                              adaptor.getSourceB(), isGfx950);
     Value c = adaptor.getDestC();
 
-    std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, isGfx950);
-
+    std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
     if (!maybeIntrinsic.has_value())
       return op.emitOpError(
           "no intrinsic matching sparse MFMA on the given chipset");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
index 65a0cd3f1f87f..a1784ce95de49 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -4,8 +4,8 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
                                 %arg4 : vector<4xbf16>, %arg5 : vector<8xbf16>,
                                 %arg6 : vector<8xi8>, %arg7 : vector<16xi8>,
                                 %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
-                                %arg10 : vector<8xf8E4M3FN>, %arg11 : vector<8xf8E5M2>,
-                                %arg12 : vector<16xf8E4M3FN>, %arg13 : vector<16xf8E5M2>,
+                                %arg10 : vector<8xf8E4M3FNUZ>, %arg11 : vector<8xf8E5M2FNUZ>,
+                                %arg12 : vector<16xf8E4M3FNUZ>, %arg13 : vector<16xf8E5M2FNUZ>,
                                 %arg14 : i32) {
   // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
@@ -29,33 +29,33 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
   // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
   // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
   // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
 
   // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
   // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
   // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
   amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
 
   func.return
 }

>From 407862a654fa673579d448ea160c99bdc52c19e6 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 12 Dec 2025 09:49:02 -0800
Subject: [PATCH 07/11] nit

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 51388cbfa458e..66136ab547022 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1426,7 +1426,7 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
     if (!outType)
       return rewriter.notifyMatchFailure(op, "type conversion failed");
 
-    // smfmac is supported on gfx942 and gfx950
+    // smfmac is supported on gfx942 and gfx950.
     if (chipset.majorVersion != 9 || chipset < kGfx942)
       return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
     bool isGfx950 = chipset >= kGfx950;

>From 6ad2e3ac4594ade6df985320ef466c9602244061 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 16 Dec 2025 14:52:01 -0800
Subject: [PATCH 08/11] format

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f1e74d4c6f61c..b538e09441c83 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -3530,12 +3530,11 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
            SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
-           ExtPackedFp8OpLowering,
-           ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
-           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
-           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
-           TransposeLoadOpLowering, AMDGPUPermlaneLowering,
-           AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+           ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
+           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+           GatherToLDSOpLowering, TransposeLoadOpLowering,
+           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
            AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
            AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
            AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,

>From 46fa5061c5ef8b05d10e7fe3cc3244e1f585dc08 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 18 Dec 2025 12:20:41 -0800
Subject: [PATCH 09/11] address review

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 27 ++++----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  6 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 18 ++++++
 .../AMDGPUToROCDL/sparse-mfma-gfx950.mlir     | 38 ++++++-----
 .../Conversion/AMDGPUToROCDL/sparse-mfma.mlir | 48 +++++++-------
 mlir/test/Dialect/AMDGPU/invalid.mlir         | 64 +++++++++++++++++++
 6 files changed, 151 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index dbc01c33f2853..a2bb755f40b16 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1013,6 +1013,11 @@ def SMFMACOutTypes : AnyTypeOf<[
     VectorOfLengthAndType<[4, 16], [I32]>
 ]>;
 
+def SparseMFMAIdxTypes : AnyTypeOf<[
+    FixedVectorOfLengthAndType<[4], [I8]>,
+    FixedVectorOfLengthAndType<[2], [I16]>
+]>;
+
 // scaled_mfma
 def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
@@ -1171,7 +1176,7 @@ def AMDGPU_SparseMFMAOp :
                    SMFMACSparseInTypes:$sourceA,
                    SMFMACDenseInTypes:$sourceB,
                    SMFMACOutTypes:$destC,
-                   I32:$sparseIdx,
+                   SparseMFMAIdxTypes:$sparseIdx,
                    DefaultValuedAttr<I32Attr, "0">:$cbsz,
                    DefaultValuedAttr<I32Attr, "0">:$abid)>,
     Results<(outs SMFMACOutTypes: $destD)> {
@@ -1190,10 +1195,10 @@ def AMDGPU_SparseMFMAOp :
       - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources
       - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources
 
-    The `sparseIdx` parameter (i32) contains packed indices identifying the
-    positions of non-zero elements in the 2:4 sparse matrix A. For 16-bit data,
-    this uses four groups of 8-bit indices; for 8-bit data, 2 groups of 16-bit
-    indices.
+    The `sparseIdx` parameter contains packed indices identifying the positions
+    of non-zero elements in the 2:4 sparse matrix A. For 16-bit source data,
+    use `vector<4xi8>` (four 8-bit index sets). For 8-bit source data, use
+    `vector<2xi16>` (two 16-bit index sets).
 
     The `cbsz` and `abid` parameters are repurposed to select the index set.
     If `cbsz == 0`, then `abid[1:0]` selects which index set to use. 
@@ -1201,20 +1206,20 @@ def AMDGPU_SparseMFMAOp :
 
     Example:
     ```mlir
-      %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx)
+      %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
         : vector<4xf16>, vector<8xf16>, vector<4xf32>
 
-      %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx)
-        : vector<8xf16>, vector<16xf16>, vector<4xf32>
+      %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
+        : vector<8xi8>, vector<16xi8>, vector<4xi32>
 
-      %2 = amdgpu.sparse_mfma 16x16x128 %matA * %matB + %matC sparse(%idx)
+      %2 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
         { cbsz = 0 : i32, abid = 1 : i32 }
-        : vector<4xi32>, vector<8xi32>, vector<4xi32>
+        : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
     ```
   }];
   let assemblyFormat = [{
     custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
-    `sparse` `(` $sparseIdx `)`
+    `sparse` `(` $sparseIdx `:` type($sparseIdx) `)`
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b538e09441c83..5dcd24019412a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1477,9 +1477,13 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
       return op.emitOpError(
           "no intrinsic matching sparse MFMA on the given chipset");
 
+    // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
+    Value sparseIdx = LLVM::BitcastOp::create(
+        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
+
     OperationState loweredOp(loc, maybeIntrinsic.value());
     loweredOp.addTypes(outType);
-    loweredOp.addOperands({a, b, c, adaptor.getSparseIdx(),
+    loweredOp.addOperands({a, b, c, sparseIdx,
                            createI32Constant(rewriter, loc, op.getCbsz()),
                            createI32Constant(rewriter, loc, op.getAbid())});
     Value lowered = rewriter.create(loweredOp)->getResult(0);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 8bcdaec0bf3b1..31f87aaa3ce74 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -671,6 +671,24 @@ LogicalResult SparseMFMAOp::verify() {
   if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
     return emitOpError("ABID must be between 0 and 3 for 16-bit source data");
 
+  // Validate sparseIdx type matches source element type.
+  auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
+  if (is8BitSource) {
+    // 8-bit source data requires vector<2xi16> sparse indices.
+    if (sparseIdxType.getNumElements() != 2 ||
+        !sparseIdxType.getElementType().isInteger(16))
+      return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
+                         "source data, but got ")
+             << sparseIdxType;
+  } else {
+    // 16-bit source data requires vector<4xi8> sparse indices.
+    if (sparseIdxType.getNumElements() != 4 ||
+        !sparseIdxType.getElementType().isInteger(8))
+      return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
+                         "source data, but got ")
+             << sparseIdxType;
+  }
+
   int64_t expectedSourceElems = (getM() * getK()) / waveSize;
   if (denseLen != expectedSourceElems)
     return emitOpError("expected " + Twine(expectedSourceElems) +
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
index abe2565f7c41b..266e0e7e15595 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir
@@ -6,48 +6,56 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>,
                                 %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
                                 %arg10 : vector<16xf8E4M3FN>, %arg11 : vector<16xf8E5M2>,
                                 %arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>,
-                                %arg14 : i32) {
+                                %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
+  // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
   // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32>
 
+  // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
+  // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
   // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
+  amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32>
 
+  // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
   // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
 
+  // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
   // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
+  amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32>
 
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
index a1784ce95de49..b2c91c3d9bed1 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir
@@ -6,56 +6,58 @@ func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>,
                                 %arg8 : vector<4xi32>, %arg9 : vector<16xi32>,
                                 %arg10 : vector<8xf8E4M3FNUZ>, %arg11 : vector<8xf8E5M2FNUZ>,
                                 %arg12 : vector<16xf8E4M3FNUZ>, %arg13 : vector<16xf8E5M2FNUZ>,
-                                %arg14 : i32) {
+                                %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) {
+  // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32
   // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
 
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
+  // CHECK: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+  // CHECK: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16>
   // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32>
 
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32
   // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+  amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
 
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
   // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
 
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
-  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32>
+  // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
   // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+  amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
 
   // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
+  amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32>
 
   // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg14) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
+  amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32>
 
   func.return
 }
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 9ece57e9ec6a3..6785946fa60b1 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -452,3 +452,67 @@ func.func @make_gather_dma_descriptor_invalid_index_types(%base: !amdgpu.tdm_gat
   amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2] : !amdgpu.tdm_gather_base<i32, i16>, vector<8xi32> -> !amdgpu.tdm_descriptor
   func.return
 }
+
+// -----
+
+func.func @sparse_mfma_dense_not_double_sparse(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op operand #1 must be vector of 16-bit float values of length 8/16 or vector of bfloat16 type values of length 8/16 or vector of 8-bit signless integer values of length 16/32 or vector of f8E4M3FN type or f8E5M2 type values of length 16/32 or vector of f8E4M3FNUZ type or f8E5M2FNUZ type values of length 16/32, but got 'vector<4xf16>'}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<4xf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_mismatched_source_types(%a: vector<4xf16>, %b: vector<8xbf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected source operands to have the same element type}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xbf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_abid_invalid_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<2xi16>) -> vector<4xi32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op ABID must be 0 or 1 for 8-bit source data}}
+  %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<2xi16>) { abid = 2 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32>
+  func.return %d : vector<4xi32>
+}
+
+// -----
+
+func.func @sparse_mfma_abid_invalid_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op ABID must be between 0 and 3 for 16-bit source data}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) { abid = 4 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_idx_type_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<4xi8>) -> vector<4xi32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected vector<2xi16> sparse indices for 8-bit source data, but got}}
+  %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<8xi8>, vector<16xi8>, vector<4xi32>
+  func.return %d : vector<4xi32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_idx_type_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<2xi16>) -> vector<4xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected vector<4xi8> sparse indices for 16-bit source data, but got}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<2xi16>) : vector<4xf16>, vector<8xf16>, vector<4xf32>
+  func.return %d : vector<4xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_source_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected 16 source values for this operation but got 8}}
+  %d = amdgpu.sparse_mfma 32x32x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32>
+  func.return %d : vector<16xf32>
+}
+
+// -----
+
+func.func @sparse_mfma_wrong_dest_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> {
+  // expected-error@+1 {{'amdgpu.sparse_mfma' op expected 4 result values for this operation but got 16}}
+  %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32>
+  func.return %d : vector<16xf32>
+}

>From 8f9416da45ea2830ead6c0624c8b0240b9e9c946 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 18 Dec 2025 14:25:58 -0800
Subject: [PATCH 10/11] nits

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index a2bb755f40b16..8565a6b727fd1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1013,7 +1013,7 @@ def SMFMACOutTypes : AnyTypeOf<[
     VectorOfLengthAndType<[4, 16], [I32]>
 ]>;
 
-def SparseMFMAIdxTypes : AnyTypeOf<[
+def SMFMACIdxTypes : AnyTypeOf<[
     FixedVectorOfLengthAndType<[4], [I8]>,
     FixedVectorOfLengthAndType<[2], [I16]>
 ]>;
@@ -1176,7 +1176,7 @@ def AMDGPU_SparseMFMAOp :
                    SMFMACSparseInTypes:$sourceA,
                    SMFMACDenseInTypes:$sourceB,
                    SMFMACOutTypes:$destC,
-                   SparseMFMAIdxTypes:$sparseIdx,
+                   SMFMACIdxTypes:$sparseIdx,
                    DefaultValuedAttr<I32Attr, "0">:$cbsz,
                    DefaultValuedAttr<I32Attr, "0">:$abid)>,
     Results<(outs SMFMACOutTypes: $destD)> {

>From abf4e28e1319743da3fd8937bac11e5fcbc6681a Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Thu, 18 Dec 2025 14:34:56 -0800
Subject: [PATCH 11/11] nit

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 31f87aaa3ce74..e77d131509add 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -679,14 +679,14 @@ LogicalResult SparseMFMAOp::verify() {
         !sparseIdxType.getElementType().isInteger(16))
       return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
                          "source data, but got ")
-             << sparseIdxType;
+             << getSparseIdx().getType();
   } else {
     // 16-bit source data requires vector<4xi8> sparse indices.
     if (sparseIdxType.getNumElements() != 4 ||
         !sparseIdxType.getElementType().isInteger(8))
       return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
                          "source data, but got ")
-             << sparseIdxType;
+             << getSparseIdx().getType();
   }
 
   int64_t expectedSourceElems = (getM() * getK()) / waveSize;



More information about the Mlir-commits mailing list