[Mlir-commits] [mlir] [WIP][AMDGPU] Added support for Sparse WMMA ops (PR #183360)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 10 10:28:15 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Ravil Dorozhinskii (ravil-mobile)
<details>
<summary>Changes</summary>
This PR adds support for sparse WMMA (swmmac) ops on gfx12 and gfx1250.
---
Patch is 41.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/183360.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td (+98)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+257-15)
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp (+71)
- (added) mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir (+85)
- (added) mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir (+51)
- (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+112)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index bc88877247546..3eb039305904f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1070,6 +1070,8 @@ def AMDGPU_WMMAOp :
The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
in case of overflow.
+ The `wave64` attribute indicates whether the op targets a 64-thread wavefront.
+
Example:
```mlir
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
@@ -1149,6 +1151,102 @@ def AMDGPU_SparseMFMAOp :
let hasVerifier = 1;
}
+// Type constraints for amdgpu.sparse_wmma (swmmac). Sparse matrix A (2:4):
+def SWMMACSparseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 8, 16], [F16]>,
+ VectorOfLengthAndType<[4, 8, 16], [BF16]>,
+ VectorOfLengthAndType<[4, 8, 32], [I8]>,
+ VectorOfLengthAndType<[8, 16], [I<4>]>,
+ VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+// Dense matrix B:
+def SWMMACDenseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[8, 16, 32], [F16]>,
+ VectorOfLengthAndType<[8, 16, 32], [BF16]>,
+ VectorOfLengthAndType<[4, 8, 16, 64], [I8]>,
+ VectorOfLengthAndType<[8, 16, 32], [I<4>]>,
+ VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+// Accumulator C and result D (must match, see AllTypesMatch on the op):
+def SWMMACOutTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 8, 16], [F32]>,
+ VectorOfLengthAndType<[4, 8], [F16]>,
+ VectorOfLengthAndType<[4, 8], [BF16]>,
+ VectorOfLengthAndType<[4, 8], [I32]>
+]>;
+// Packed 2:4 sparsity indices (bitcast to a single i32 during lowering):
+def SWMMACIdxTypes : AnyTypeOf<[
+ FixedVectorOfLengthAndType<[4], [I8]>,
+]>;
+
+
+def AMDGPU_SparseWMMAOp :
+ AMDGPU_Op<"sparse_wmma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[32, 64, 128]>]>:$k,
+ SWMMACSparseInTypes:$sourceA,
+ SWMMACDenseInTypes:$sourceB,
+ SWMMACOutTypes:$destC,
+ SWMMACIdxTypes:$sparseIdx,
+ UnitAttr:$unsignedA,
+ UnitAttr:$unsignedB,
+ UnitAttr:$reuseA,
+ UnitAttr:$reuseB,
+ UnitAttr:$clamp,
+ UnitAttr:$wave64)>,
+ Results<(outs SWMMACOutTypes: $destD)> {
+ let summary = "MLIR wrapper for RDNA sparse wmma (swmmac) instructions";
+ let description = [{
+ The `amdgpu.sparse_wmma` op is an MLIR wrapper around intrinsics for various
+ `swmmac` instructions in the AMDGPU architecture, which perform matrix
+ multiply-accumulate operations using 2:4 structured sparsity on matrix A
+ with dense matrices B, C, and D.
+
+ On gfx12, swmmac intrinsics support:
+ - M=N=16, K=32 for f16, bf16, i8, i4, fp8 and bf8 sources
+ - M=N=16, K=64 for i4 sources
+
+ On gfx1250 (wave32 only), swmmac intrinsics support:
+ - M=N=16, K=64 for f16 and bf16 sources
+ - M=N=16, K=128 for fp8, bf8 and i8 sources
+
+ The `sparseIdx` operand is a `vector<4xi8>` of packed indices identifying
+ the positions of the non-zero elements in the 2:4 sparse matrix A; it is
+ bitcast to an `i32` when lowered to ROCDL. The same index type is used for
+ all source element types.
+
+ `unsignedA` and `unsignedB` flag that the integer (i8/i4) inputs are unsigned.
+
+ The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
+ in case of overflow.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+ : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+ %1 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+ {unsignedA, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+
+ %2 = amdgpu.sparse_wmma 16x16x128 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+ {reuseA}
+ : vector<32xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+ ```
+ }];
+ let assemblyFormat = [{
+ custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+ `sparse` `(` $sparseIdx `:` type($sparseIdx) `)`
+ attr-dict
+ `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_GatherToLDSOp :
AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 379d6180596e9..47387e8ebde3c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -680,10 +680,11 @@ static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
return input;
}
-/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
-static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
- Location loc, Value input,
- bool allowBf16 = true) {
+/// Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL
+/// types.
+static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input,
+ bool allowBf16 = true) {
Type inputType = input.getType();
auto vectorType = cast<VectorType>(inputType);
// bf16 -> i16 when not allowed (pre-gfx950).
@@ -695,8 +696,10 @@ static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
- return LLVM::BitcastOp::create(
- rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+ Type castType = (numWords > 1)
+ ? Type{VectorType::get(numWords, rewriter.getI32Type())}
+ : rewriter.getI32Type();
+ return LLVM::BitcastOp::create(rewriter, loc, castType, input);
}
return input;
}
@@ -1339,6 +1342,162 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return std::nullopt;
}
+/// Describes the ROCDL `swmmac` intrinsic chosen for a SparseWMMA op: its
+/// operation name and whether it accepts the sign (`signA`/`signB`), reuse
+/// (`reuseA`/`reuseB`) and `clamp` attributes.
+struct SparseWMMAOpInfo {
+ StringRef name;
+ bool useSign;
+ bool useReuse;
+ bool useClamp;
+};
+/// Returns the swmmac intrinsic for `swmmac` supported on `chipset`, if any.
+static std::optional<SparseWMMAOpInfo>
+sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset) {
+ // gfx12 (minor 0) and gfx1250 (wave32 only) have disjoint variant sets.
+ Type sourceAElem = getElementTypeOrSelf(swmmac.getSourceA().getType());
+ Type sourceBElem = getElementTypeOrSelf(swmmac.getSourceB().getType());
+ Type destElem = getElementTypeOrSelf(swmmac.getDestC().getType());
+
+ uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
+
+ if ((m != 16) || (n != 16))
+ return std::nullopt;
+
+ const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+ if (isRDNA4) {
+ if (k == 32) {
+ if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
+ false};
+ if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
+ false};
+ if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
+ false};
+ if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
+ false};
+ if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
+ sourceBElem.isInteger(8))
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
+ true};
+ if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
+ sourceBElem.isInteger(4))
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
+ true};
+ if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+ sourceBElem.isF8E4M3FN())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
+ false, false};
+ if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+ sourceBElem.isF8E5M2())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
+ false, false};
+ if (destElem.isF32() && sourceAElem.isF8E5M2() &&
+ sourceBElem.isF8E4M3FN())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
+ false, false};
+ if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
+ false, false};
+ }
+ if (k == 64) {
+ if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
+ sourceBElem.isInteger(4))
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,
+ true};
+ }
+ }
+
+ const bool isGFX1250 = chipset == kGfx1250;
+ const bool isWavesize64 = swmmac.getWave64();
+ if (isGFX1250 && !isWavesize64) {
+ if (k == 64) {
+ if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
+ false};
+ if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
+ false};
+ if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
+ false};
+ if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
+ false};
+ }
+ if (k == 128) {
+ if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+ sourceBElem.isF8E4M3FN())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
+ true, false};
+ if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
+ sourceBElem.isF8E5M2())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
+ true, false};
+ if (destElem.isF32() && sourceAElem.isF8E5M2() &&
+ sourceBElem.isF8E4M3FN())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
+ true, false};
+ if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
+ true, false};
+ if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
+ sourceBElem.isF8E4M3FN())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
+ true, false};
+ if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
+ sourceBElem.isF8E5M2())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
+ true, false};
+ if (destElem.isF16() && sourceAElem.isF8E5M2() &&
+ sourceBElem.isF8E4M3FN())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
+ true, false};
+ if (destElem.isF16() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
+ true, false};
+ // There is intentionally no f16-accumulating i8 x i8 case here: the only
+ // gfx1250 i8 swmmac variant accumulates into i32 (below). Mapping an f16
+ // destination with i8 sources onto an fp8/bf8 intrinsic would silently
+ // reinterpret the integer data as 8-bit floats, so that combination falls
+ // through to std::nullopt and is reported as unsupported by the caller.
+ if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
+ sourceBElem.isInteger(8))
+ return SparseWMMAOpInfo{
+ ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,
+ true};
+ }
+ }
+
+ return std::nullopt;
+}
+
namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
@@ -1485,10 +1644,10 @@ struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
bool isGfx950 = chipset >= kGfx950;
- Value a = convertSparseMFMAVectorOperand(rewriter, loc,
- adaptor.getSourceA(), isGfx950);
- Value b = convertSparseMFMAVectorOperand(rewriter, loc,
- adaptor.getSourceB(), isGfx950);
+ Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
+ isGfx950);
+ Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
+ isGfx950);
Value c = adaptor.getDestC();
std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
@@ -1592,6 +1751,88 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
+ SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto outType =
+ typeConverter->convertType<VectorType>(op.getDestD().getType());
+ if (!outType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ // Select the swmmac intrinsic. TODO (Ravil): follow-up left unspecified.
+ std::optional<SparseWMMAOpInfo> maybeIntrinsic =
+ sparseWMMAOpToIntrinsic(op, chipset);
+
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching Sparse WMMA on the given chipset");
+ SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
+
+ SmallVector<NamedAttribute> attrs;
+ // Flags are forwarded only if the selected intrinsic accepts them; else error.
+ if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
+ return op->emitOpError("intrinsic doesn't support unsign");
+ if (intrinsic.useSign) {
+ if (auto attr = op.getUnsignedAAttr())
+ attrs.push_back({"signA", attr});
+ if (auto attr = op.getUnsignedBAttr())
+ attrs.push_back({"signB", attr});
+ }
+
+ if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
+ return op->emitOpError("intrinsic doesn't support reuse");
+ if (intrinsic.useReuse) {
+ if (auto attr = op.getReuseAAttr())
+ attrs.push_back({"reuseA", attr});
+ if (auto attr = op.getReuseBAttr())
+ attrs.push_back({"reuseB", attr});
+ }
+
+ if (op.getClamp() && !intrinsic.useClamp)
+ return op->emitOpError("intrinsic doesn't support clamp");
+ if (intrinsic.useClamp && op.getClampAttr())
+ attrs.push_back({"clamp", op.getClampAttr()});
+ // NOTE(review): despite its name, this only matches major 12, minor >= 5.
+ const bool isGFX1250orHigher =
+ chipset.majorVersion == 12 && chipset.minorVersion >= 5;
+ Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
+ isGFX1250orHigher);
+ Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
+ isGFX1250orHigher);
+ Value c = adaptor.getDestC();
+ VectorType rawOutType = outType;
+ if (!isGFX1250orHigher) {
+ c = convertSparseVectorOperand(rewriter, loc, adaptor.getDestC(), false);
+ rawOutType = cast<VectorType>(c.getType());
+ }
+
+ // Bitcast sparse indices from vector<4xi8> to i32.
+ Value sparseIdx = LLVM::BitcastOp::create(
+ rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
+
+ OperationState loweredOp(loc, intrinsic.name);
+ loweredOp.addTypes(rawOutType);
+ loweredOp.addOperands({a, b, c, sparseIdx});
+ loweredOp.addAttributes(attrs);
+ Operation *lowered = rewriter.create(loweredOp);
+ // If C was converted above, bitcast the raw result back to the op's type.
+ Operation *maybeCastBack = lowered;
+ if (rawOutType != outType)
+ maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+ lowered->getResult(0));
+ rewriter.replaceOp(op, maybeCastBack->getResults());
+
+ return success();
+ }
+};
+
struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
@@ -3833,11 +4074,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
- ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering, TransposeLoadOpLowering,
- AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+ SparseWMMAOpLowering, ExtPackedFp8OpLowering,
+ ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
+ PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+ AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index f452d2de15dc8..b715f4ab93231 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -670,6 +670,77 @@ LogicalResult SparseMFMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// SparseWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SparseWMMAOp::verify() {
+ auto sparseType = cast<VectorType>(getSourceA().getType());
+ auto denseType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ Type sparseElem = sparseType.getElementType();
+ Type denseElem = denseType.getElementType();
+ Type destElem = destType.getElementType();
+ int64_t sparseLen = sparseType.getNumElements();
+ int64_t denseLen = denseType.getNumElements();
+ int64_t destLen = destType.getNumElements();
+
+ uint32_t m = getM(), n = getN(), k = getK();
+ if ((m != 16) || (n != 16))
+ return emitOpError("expected MxN to be exactly 16x16");
+
+ const bool isWavesize64 = getWave64();
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/183360
More information about the Mlir-commits
mailing list