[Mlir-commits] [mlir] [WIP][AMDGPU] Added support for Sparse WMMA ops (PR #183360)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 10 10:28:15 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Ravil Dorozhinskii (ravil-mobile)

<details>
<summary>Changes</summary>

This PR adds support for Sparse WMMA ops (gfx12 and gfx1250).

---

Patch is 41.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/183360.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td (+98) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+257-15) 
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp (+71) 
- (added) mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir (+85) 
- (added) mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir (+51) 
- (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+112) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index bc88877247546..3eb039305904f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1070,6 +1070,8 @@ def AMDGPU_WMMAOp :
     The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
 
+    The `wave64` attribute indicates whether the op targets a 64-thread wavefront.
+
     Example:
     ```mlir
       %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
@@ -1149,6 +1151,102 @@ def AMDGPU_SparseMFMAOp :
   let hasVerifier = 1;
 }
 
+// sparse_wmma (swmmac)
+// Vector types accepted for the `sourceA` operand of `amdgpu.sparse_wmma`.
+def SWMMACSparseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 8, 16], [F16]>,
+    VectorOfLengthAndType<[4, 8, 16], [BF16]>,
+    VectorOfLengthAndType<[4, 8, 32], [I8]>,
+    VectorOfLengthAndType<[8, 16], [I<4>]>,
+    VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FN, F8E5M2]>,
+    VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+// Vector types accepted for the `sourceB` operand of `amdgpu.sparse_wmma`.
+def SWMMACDenseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[8, 16, 32], [F16]>,
+    VectorOfLengthAndType<[8, 16, 32], [BF16]>,
+    VectorOfLengthAndType<[4, 8, 16, 64], [I8]>,
+    VectorOfLengthAndType<[8, 16, 32], [I<4>]>,
+    VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FN, F8E5M2]>,
+    VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+// Vector types accepted for the `destC` operand and the `destD` result of
+// `amdgpu.sparse_wmma`.
+def SWMMACOutTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 8, 16], [F32]>,
+    VectorOfLengthAndType<[4, 8], [F16]>,
+    VectorOfLengthAndType<[4, 8], [BF16]>,
+    VectorOfLengthAndType<[4, 8], [I32]>
+]>;
+
+// Type accepted for the `sparseIdx` operand of `amdgpu.sparse_wmma`; only
+// `vector<4xi8>` is admitted.
+def SWMMACIdxTypes : AnyTypeOf<[
+    FixedVectorOfLengthAndType<[4], [I8]>,
+]>;
+
+
+def AMDGPU_SparseWMMAOp :
+    AMDGPU_Op<"sparse_wmma", [AllTypesMatch<["destC", "destD"]>,
+                              Pure]>,
+    Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[32, 64, 128]>]>:$k,
+                   SWMMACSparseInTypes:$sourceA,
+                   SWMMACDenseInTypes:$sourceB,
+                   SWMMACOutTypes:$destC,
+                   SWMMACIdxTypes:$sparseIdx,
+                   UnitAttr:$unsignedA,
+                   UnitAttr:$unsignedB,
+                   UnitAttr:$reuseA,
+                   UnitAttr:$reuseB,
+                   UnitAttr:$clamp,
+                   UnitAttr:$wave64)>,
+    Results<(outs SWMMACOutTypes: $destD)> {
+  let summary = "MLIR wrapper for sparse WMMA (swmmac) instructions";
+  let description = [{
+    The `amdgpu.sparse_wmma` op is an MLIR wrapper around intrinsics for various
+    `swmmac` instructions in the AMDGPU architecture, which perform matrix
+    multiply-accumulate operations using 2:4 structured sparsity on matrix A
+    with dense matrices B, C, and D.
+
+    On gfx12, swmmac intrinsics support:
+      - M=N=16, K=32 for f16, bf16, i8, i4, fp8 and bf8 sources
+      - M=N=16, K=64 for i4 sources
+
+    On gfx1250, swmmac intrinsics support:
+      - M=N=16, K=64 for f16 and bf16 sources
+      - M=N=16, K=128 for fp8, bf8 and i8 sources
+
+    The `sparseIdx` parameter contains packed indices identifying the positions
+    of non-zero elements in the 2:4 sparse matrix A. It is always passed as
+    `vector<4xi8>`, which the lowering bitcasts to a single `i32`.
+
+    `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
+
+    `reuseA` and `reuseB` are forwarded to the intrinsic when it supports them;
+    lowering fails if they are set on an intrinsic that does not.
+
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
+    in case of overflow.
+
+    The `wave64` attribute indicates whether the op targets a 64-thread wavefront.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+        : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+      %1 = amdgpu.sparse_wmma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+        : vector<8xi8>, vector<16xi8>, vector<4xi32>
+
+      %2 = amdgpu.sparse_wmma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+        { unsignedA = 0 : i1, unsignedB = 1 : i1, clamp = 0 : i1 }
+        : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32>
+    ```
+  }];
+  let assemblyFormat = [{
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+    `sparse` `(` $sparseIdx `:` type($sparseIdx) `)`
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_GatherToLDSOp :
     AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
     Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 379d6180596e9..47387e8ebde3c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -680,10 +680,11 @@ static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
 
-/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
-static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                            Location loc, Value input,
-                                            bool allowBf16 = true) {
+/// Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL
+/// types.
+static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter,
+                                        Location loc, Value input,
+                                        bool allowBf16 = true) {
   Type inputType = input.getType();
   auto vectorType = cast<VectorType>(inputType);
   // bf16 -> i16 when not allowed (pre-gfx950).
@@ -695,8 +696,10 @@ static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
       vectorType.getElementTypeBitWidth() <= 8) {
     int64_t numWords = llvm::divideCeil(
         vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
-    return LLVM::BitcastOp::create(
-        rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+    Type castType = (numWords > 1)
+                        ? Type{VectorType::get(numWords, rewriter.getI32Type())}
+                        : rewriter.getI32Type();
+    return LLVM::BitcastOp::create(rewriter, loc, castType, input);
   }
   return input;
 }
@@ -1339,6 +1342,162 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
   return std::nullopt;
 }
 
+/// Describes the `rocdl` intrinsic chosen for a SparseWMMA (`swmmac`) op: its
+/// operation name and which optional attribute groups (sign, reuse, clamp)
+/// the intrinsic accepts.
+struct SparseWMMAOpInfo {
+  StringRef name;
+  bool useSign;
+  bool useReuse;
+  bool useClamp;
+};
+
+/// Returns the `rocdl` intrinsic corresponding to a SparseWMMA operation
+/// `swmmac` if one exists. This includes checking to ensure the intrinsic is
+/// supported on the architecture you are compiling for.
+static std::optional<SparseWMMAOpInfo>
+sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset) {
+
+  Type sourceAElem = getElementTypeOrSelf(swmmac.getSourceA().getType());
+  Type sourceBElem = getElementTypeOrSelf(swmmac.getSourceB().getType());
+  Type destElem = getElementTypeOrSelf(swmmac.getDestC().getType());
+
+  uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
+
+  if ((m != 16) || (n != 16))
+    return std::nullopt;
+
+  // Intrinsics selected when the chipset version is 12.0.x.
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+  if (isRDNA4) {
+    if (k == 32) {
+      if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
+            false};
+      if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
+            false};
+      if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
+            false};
+      if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
+            false};
+      if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
+          sourceBElem.isInteger(8))
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
+            true};
+      if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
+          sourceBElem.isInteger(4))
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
+            true};
+      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+          sourceBElem.isF8E4M3FN())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
+            false, false};
+      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+          sourceBElem.isF8E5M2())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
+            false, false};
+      if (destElem.isF32() && sourceAElem.isF8E5M2() &&
+          sourceBElem.isF8E4M3FN())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
+            false, false};
+      if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
+            false, false};
+    }
+    if (k == 64) {
+      if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
+          sourceBElem.isInteger(4))
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,
+            true};
+    }
+  }
+
+  // Intrinsics selected on gfx1250 when the op is not marked `wave64`.
+  const bool isGFX1250 = chipset == kGfx1250;
+  const bool isWavesize64 = swmmac.getWave64();
+  if (isGFX1250 && !isWavesize64) {
+    if (k == 64) {
+      if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
+            false};
+      if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
+            false};
+      if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
+            false};
+      if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
+            false};
+    }
+    if (k == 128) {
+      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+          sourceBElem.isF8E4M3FN())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
+            true, false};
+      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+          sourceBElem.isF8E5M2())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
+            true, false};
+      if (destElem.isF32() && sourceAElem.isF8E5M2() &&
+          sourceBElem.isF8E4M3FN())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
+            true, false};
+      if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
+            true, false};
+      if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
+          sourceBElem.isF8E4M3FN())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
+            true, false};
+      if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
+          sourceBElem.isF8E5M2())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
+            true, false};
+      if (destElem.isF16() && sourceAElem.isF8E5M2() &&
+          sourceBElem.isF8E4M3FN())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
+            true, false};
+      if (destElem.isF16() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
+            true, false};
+      if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
+          sourceBElem.isInteger(8))
+        return SparseWMMAOpInfo{
+            ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,
+            true};
+    }
+  }
+
+  return std::nullopt;
+}
+
 namespace {
 struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
@@ -1485,10 +1644,10 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
       return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
     bool isGfx950 = chipset >= kGfx950;
 
-    Value a = convertSparseMFMAVectorOperand(rewriter, loc,
-                                             adaptor.getSourceA(), isGfx950);
-    Value b = convertSparseMFMAVectorOperand(rewriter, loc,
-                                             adaptor.getSourceB(), isGfx950);
+    Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
+                                         isGfx950);
+    Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
+                                         isGfx950);
     Value c = adaptor.getDestC();
 
     std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
@@ -1592,6 +1751,88 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+/// Lowers `amdgpu.sparse_wmma` to the ROCDL `swmmac` intrinsic selected by
+/// `sparseWMMAOpToIntrinsic` for the target chipset. Emits an op error when no
+/// intrinsic matches, or when the op sets sign/reuse/clamp attributes that the
+/// selected intrinsic does not support.
+struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
+  SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
+
+  /// Target chipset used to select the intrinsic and operand conversions.
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    // TODO (Ravil)
+    // Select the intrinsic from the element types, K and the chipset.
+    std::optional<SparseWMMAOpInfo> maybeIntrinsic =
+        sparseWMMAOpToIntrinsic(op, chipset);
+
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching Sparse WMMA on the given chipset");
+    SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
+
+    // Attributes to attach to the lowered intrinsic op.
+    SmallVector<NamedAttribute> attrs;
+
+    // Each attribute group is rejected if set on an intrinsic that does not
+    // support it, and otherwise forwarded under the intrinsic's attribute
+    // name when present on the op.
+    if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
+      return op->emitOpError("intrinsic doesn't support unsign");
+    if (intrinsic.useSign) {
+      if (auto attr = op.getUnsignedAAttr())
+        attrs.push_back({"signA", attr});
+      if (auto attr = op.getUnsignedBAttr())
+        attrs.push_back({"signB", attr});
+    }
+
+    if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
+      return op->emitOpError("intrinsic doesn't support reuse");
+    if (intrinsic.useReuse) {
+      if (auto attr = op.getReuseAAttr())
+        attrs.push_back({"reuseA", attr});
+      if (auto attr = op.getReuseBAttr())
+        attrs.push_back({"reuseB", attr});
+    }
+
+    if (op.getClamp() && !intrinsic.useClamp)
+      return op->emitOpError("intrinsic doesn't support clamp");
+    if (intrinsic.useClamp && op.getClampAttr())
+      attrs.push_back({"clamp", op.getClampAttr()});
+
+    // The flag is passed as `allowBf16` to the operand conversion, so bf16
+    // operands are only allowed through unchanged on chipsets 12.5 and above.
+    const bool isGFX1250orHigher =
+        chipset.majorVersion == 12 && chipset.minorVersion >= 5;
+    Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
+                                         isGFX1250orHigher);
+    Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
+                                         isGFX1250orHigher);
+    Value c = adaptor.getDestC();
+    VectorType rawOutType = outType;
+    // Below 12.5 the accumulator goes through the same conversion with bf16
+    // disallowed, and the intrinsic's result type follows the converted
+    // accumulator type.
+    if (!isGFX1250orHigher) {
+      c = convertSparseVectorOperand(rewriter, loc, adaptor.getDestC(), false);
+      rawOutType = cast<VectorType>(c.getType());
+    }
+
+    // Bitcast sparse indices from vector<4xi8> to i32.
+    Value sparseIdx = LLVM::BitcastOp::create(
+        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
+
+    // Build the intrinsic by name, since it is only known at this point as a
+    // string.
+    OperationState loweredOp(loc, intrinsic.name);
+    loweredOp.addTypes(rawOutType);
+    loweredOp.addOperands({a, b, c, sparseIdx});
+    loweredOp.addAttributes(attrs);
+    Operation *lowered = rewriter.create(loweredOp);
+
+    // Bitcast the result back to the converted result type if the accumulator
+    // conversion above changed it.
+    Operation *maybeCastBack = lowered;
+    if (rawOutType != outType)
+      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+                                              lowered->getResult(0));
+    rewriter.replaceOp(op, maybeCastBack->getResults());
+
+    return success();
+  }
+};
+
 struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
   ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
@@ -3833,11 +4074,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
            SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
-           ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
-           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GatherToLDSOpLowering, TransposeLoadOpLowering,
-           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+           SparseWMMAOpLowering, ExtPackedFp8OpLowering,
+           ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+           AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
            AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
            AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
            AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index f452d2de15dc8..b715f4ab93231 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -670,6 +670,77 @@ LogicalResult SparseMFMAOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SparseWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SparseWMMAOp::verify() {
+  auto sparseType = cast<VectorType>(getSourceA().getType());
+  auto denseType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sparseElem = sparseType.getElementType();
+  Type denseElem = denseType.getElementType();
+  Type destElem = destType.getElementType();
+  int64_t sparseLen = sparseType.getNumElements();
+  int64_t denseLen = denseType.getNumElements();
+  int64_t destLen = destType.getNumElements();
+
+  uint32_t m = getM(), n = getN(), k = getK();
+  if ((m != 16) || (n != 16))
+    return emitOpError("expected MxN to be exactly 16x16");
+
+  const bool isWavesize64 = getWave64();
+ ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/183360


More information about the Mlir-commits mailing list