[Mlir-commits] [mlir] [MLIR] Support sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (PR #168686)
llvmlistbot at llvm.org
Wed Nov 19 01:45:32 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Kirill Vedernikov (kvederni)
This change adds sparse MMA intrinsics to the MLIR -> NVVM IR -> NVPTX flow. The NVVM and NVPTX implementation is based on PTX ISA 9.0.
---
Patch is 97.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168686.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+276-1)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+474)
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir (+221)
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir (+411)
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir (+390)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8d5bc7333d47f..b8f69f6b2cb98 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1955,6 +1955,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
/// Generate the signature part of the mma intrinsic name.
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
+ // FP8/F8F6F4 ops are identified by A/B inputs, accumulator, and result type.
+ !or(!eq(A.ptx_elt_type, "e4m3"),
+ !eq(A.ptx_elt_type, "e5m2"),
+ !eq(A.ptx_elt_type, "e3m2"),
+ !eq(A.ptx_elt_type, "e2m3"),
+ !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
@@ -2081,6 +2087,31 @@ class NVVM_MMA_OPS {
list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+
+ list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,16>, GEOM<16,8,32>],
+ ["bf16"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,8>, GEOM<16,8,16>],
+ ["tf32"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,16>, GEOM<16,8,32>],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+ list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"], ["f16", "f32"]>.ret;
+ list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,64>, GEOM<16,8,128>],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,32>, GEOM<16,8,64>],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat(
+ bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
+ subint_mma_sp_ops, int_mma_sp_ops);
+
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -2187,6 +2218,29 @@ def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflo
let assemblyFormat = "`<` $value `>`";
}
+/// Sparse MMA metadata types
+def MMASpMetadataStandard : I32EnumAttrCase<"standard", 0>;
+def MMASpMetadataOrdered : I32EnumAttrCase<"ordered", 1>;
+def MMASpMetadata : I32EnumAttr<"MMASpMetadata", "Sparse MMA metadata ordering",
+ [MMASpMetadataStandard, MMASpMetadataOrdered]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def MMASpMetadataAttr : EnumAttr<NVVM_Dialect, MMASpMetadata, "mma_sp_metadata"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+/// MMA kind types (for mixed-precision FP8/FP6/FP4 operations)
+def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>;
+def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind",
+ [MMAKindF8F6F4]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
/// Attribute to hold the MMA shape
def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
let summary = "Attribute for MMA operation shape.";
@@ -2330,12 +2384,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
+def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>;
+def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>;
+def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>;
+def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>;
+def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>;
def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
[MMATypeF16, MMATypeF32, MMATypeTF32,
MMATypeBF16, MMATypeS8, MMATypeU8,
MMATypeS32, MMATypeS4, MMATypeU4,
- MMATypeB1, MMATypeF64]> {
+ MMATypeB1, MMATypeF64,
+ MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -2772,6 +2832,221 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
}
+/// Generate the intrinsic ID enum value of the mma.sp.sync intrinsic.
+class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string id = "llvm::Intrinsic::nvvm_mma"
+ # "_" # !subst("::", "_", Metadata)
+ # "_" # A.geom
+ # "_row_col"
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+ # !if(Satfinite, "_satfinite", "")
+ # signature;
+}
+
+// Returns true if this combination of fragment types, metadata, kind, and satf
+// for MMA.SP ops is supported; false otherwise.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+ string kind, int satf> {
+ // MMA.SP ops use the fixed row.col layout; check the remaining constraints.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit is_int = !or(!eq(a_type, "s8"),
+ !eq(a_type, "u8"),
+ !eq(a_type, "s4"),
+ !eq(a_type, "u4"));
+
+ bit ret = !cond(
+
+ // Limit satf to valid types
+ !and(!eq(satf, 1),
+ !eq(is_int, 0)): false,
+
+ // f16/bf16/tf32 require A and B to be the same type.
+ !and(!or(!eq(a_type, "f16"),
+ !eq(a_type, "bf16"),
+ !eq(a_type, "tf32")),
+ !ne(a_type, b_type)): false,
+
+ // m16n8k16, m16n8k32, and m16n8k64 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32"),
+ !eq(geom, "m16n8k64")),
+ !ne(c_type, d_type)): false,
+
+ !and(!eq(kind, ""),
+ !or(!eq(a_type, "e3m2"),
+ !eq(a_type, "e2m3"),
+ !eq(a_type, "e2m1"),
+ !eq(b_type, "e3m2"),
+ !eq(b_type, "e2m3"),
+ !eq(b_type, "e2m1"))): false,
+
+ !and(!eq(kind, ""),
+ !eq(geom, "m16n8k64"),
+ !or(!eq(c_type, "f16"),
+ !eq(d_type, "f16"))): false,
+
+ !and(!ne(kind, ""),
+ !or(!eq(metadata, "sp"),
+ !ne(geom, "m16n8k64"),
+ !eq(is_int, 1))): false,
+
+ // All other combinations are OK.
+ true: true
+ );
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.sync
+/// intrinsic enum value.
+class MMA_SP_SYNC_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
+ !foreach(metadata, ["sp", "sp::ordered_metadata"],
+ !foreach(kind, ["", "kind::f8f6f4"],
+ !foreach (satf, [0, 1],
+ !if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
+ # " && (satf.has_value() ? " # satf # " == static_cast<int>(*satf) : true)"
+ # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata")
+ # " && " # !if(!eq(kind, ""), "!hasKind", "hasKind") # ")\n"
+ # " return " #
+ MMA_SP_SYNC_NAME<metadata, kind, satf, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // satf
+ ) // kind
+ ) // metadata
+ ); // all_mma_sp_sync_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
+
+ let summary = "cooperative sparse matrix-multiply and accumulate";
+
+ let description = [{
+ The `nvvm.mma.sp.sync` operation collectively performs the sparse operation
+ `D = matmul(A_sparse, B) + C` using all threads in a warp.
+
+ This operation is similar to `nvvm.mma.sync` but with structured sparsity
+ in the A operand. The sparsity follows the 2:4 structured sparse pattern
+ where 2 out of every 4 elements are non-zero.
+
+ All the threads in the warp must execute the same `mma.sp.sync` operation.
+
+ The `sparseMetadata` operand provides the sparsity indices that indicate
+ which elements in the A operand are non-zero. The `sparsitySelector`
+ controls how the indices are distributed among threads in the warp and
+ should typically be 0 or 1.
+
+ The optional `metadataType` attribute specifies the metadata ordering:
+ - `standard` (default): Uses standard sparse metadata ordering
+ - `ordered`: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
+
+ The optional `kind` attribute specifies mixed-precision modes for FP8/FP6/FP4 operations:
+ - `f8f6f4`: Enables the e3m2, e2m3 (FP6) and e2m1 (FP4) element types and the f16 accumulator (PTX ISA 8.7+, sm_90+)
+ - Only valid with ordered metadata and the m16n8k64 shape
+
+ The shapes, layouts, and data types follow the same constraints as the
+ regular `nvvm.mma.sync` operation, but the A operand contains only the
+ non-zero elements in compressed format.
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+
+ // With ordered metadata:
+ %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {metadataType = #nvvm.mma_sp_metadata<ordered>,
+ shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ ```
+ }];
+
+ let results = (outs LLVM_AnyStruct:$res);
+ let arguments = (ins NVVM_MMAShapeAttr:$shape,
+ OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
+ OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+ OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+ OptionalAttr<MMASpMetadataAttr>:$metadataType,
+ OptionalAttr<MMAKindAttr>:$kind,
+ Variadic<LLVM_Type>:$operandA,
+ Variadic<LLVM_Type>:$operandB,
+ Variadic<LLVM_Type>:$operandC,
+ I32:$sparseMetadata,
+ I32:$sparsitySelector);
+
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ std::optional<MMAIntOverflow> satf,
+ std::optional<MMASpMetadata> metadata,
+ std::optional<MMAKind> kind,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+ llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
+ bool orderedMetadata = metadata.has_value() &&
+ *metadata == MMASpMetadata::ordered;
+ bool hasKind = kind.has_value();
+ }],
+ MMA_SP_SYNC_INTR<>.id, [{
+ return 0;
+ }
+
+ static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
+ bool isAccumulator);
+
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$sparseMetadata, "Value":$sparsitySelector,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<MMAIntOverflow>":$intOverflow,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ac427dbe3941..8db724dd0a25b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -940,6 +940,480 @@ LogicalResult MmaOp::verify() {
return success();
}
+MMATypes MmaSpOp::accumPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
+ getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
+ assert(val.has_value() && "accumulator PTX type should always be inferrable");
+ return val.value();
+}
+
+MMATypes MmaSpOp::resultPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val =
+ MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
+ assert(val.has_value() && "result PTX type should always be inferrable");
+ return val.value();
+}
+
+mlir::NVVM::IDArgPair
+MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MmaSpOp>(op);
+
+ // Get operands
+ llvm::SmallVector<llvm::Value *> args;
+ for (mlir::Value v : thisOp.getOperands())
+ args.push_back(mt.lookupValue(v));
+
+ // Get intrinsic ID using the existing getIntrinsicID method
+ auto intId = MmaSpOp::getIntrinsicID(
+ thisOp.getShape().getM(), thisOp.getShape().getN(), thisOp.getShape().getK(),
+ thisOp.getIntOverflowBehavior(),
+ thisOp.getMetadataType(),
+ thisOp.getKind(),
+ *thisOp.getMultiplicandAPtxType(),
+ *thisOp.getMultiplicandBPtxType(),
+ thisOp.accumPtxType(),
+ thisOp.resultPtxType());
+
+ return {intId, args};
+}
+
+void MmaSpOp::print(OpAsmPrinter &p) {
+ SmallVector<Type, 4> regTypes;
+ struct OperandFragment {
+ StringRef operandName;
+ StringRef ptxTypeAttr;
+ SmallVector<Value, 4> regs;
+ explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ : operandName(name), ptxTypeAttr(ptxTypeName) {}
+ };
+
+ std::array<OperandFragment, 5> frags{
+ OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ OperandFragment("C", ""),
+ OperandFragment("sparseMetadata", ""),
+ OperandFragment("selector", "")};
+ SmallVector<StringRef, 4> ignoreAttrNames{
+ mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
+
+ // Handle variadic operands A, B, C
+ for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
+ auto &frag = frags[fragIdx];
+ auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+ for (auto operandIdx = varOperandSpec.first;
+ operandIdx < varOperandSpec.first + varOperandSpec.second;
+ operandIdx++) {
+ frag.regs.push_back(this->getOperand(operandIdx));
+ if (operandIdx == varOperandSpec.first) {
+ regTypes.push_back(this->getOperand(operandIdx).getType());
+ }
+ }
+ std::optional<MMATypes> inferredType =
+ MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ if (inferredType)
+ ignoreAttrNames.push_back(frag.ptxTypeAttr);
+ }
+
+ // Handle sparse metadata and selector (single operands)
+ frags[3].regs.push_back(getSparseMetadata());
+ frags[4].regs.push_back(getSparsitySelector());
+
+ auto printMmaSpOperand = [&](const OperandFragment &frag) -> void {
+ p << " " << frag.operandName;
+ p << "[";
+ p.printOperands(frag.regs);
+ p << "]";
+ };
+
+ for (const auto &frag : frags)
+ printMmaSpOperand(frag);
+
+ p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
+ p << " : ";
+ p << "(";
+ for (int i = 0; i < 3; ++i) {
+ p << regTypes[i];
+ if (i < 2) p << ", ";
+ }
+ p << ") -> " << getResult().getType();
+}
+
+void MmaSpOp::build(OpBuilder &builder, OperationState &result,
+ Type resultType, ValueRange operandA, ValueRange operandB,
+ ValueRange operandC, Value sparseMetadata, Value sparsitySelector,
+ ArrayRef<int64_t> shape,
+ std::optional<MMAIntOverflow> intOverflow,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+
+ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+ MLIRContext *ctx = builder.getContext();
+ result.addAttribute(
+ "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+ result.addOperands(operandA);
+ result.addOperands(operandB);
+ result.addOperands(operandC);
+ result.addOperands(sparseMetadata);
+ result.addOperands(sparsitySelector);
+
+ if (multiplicandPtxTypes) {
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ } else {
+ if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+ if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+ }
+
+ if (intOverflow.has_value())
+ result.addAttribute("intOverflowBehavior",
+ MMAIntOverflowAttr::get(ctx, *intOverflow));
+
+ result.addTypes(resultType);
+ result.addAttribute(
+ MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()),
+ 1, 1})); // sparseMetadata and sparsitySelector
+}
+
+ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
+ struct OperandFragment {
+ std::optional<MMATypes> elemtype;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+ SmallVector<Type> regTypes;
+ };
+
+ Builder &builder = parser.getBuilder();
+ std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
+
+ NamedAttrList namedAttributes;
+
+ // A helper to parse the operand segments.
+ auto parseMmaSpOperand = [&](StringRef operandName,
+ OperandFragment &frag) -> LogicalResult {
+ if (parser.parseKeyword(operandName).failed())
+ return failure();
+ if (parser
+ .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
+ return failure();
+ return success();
+ };
+
+ // Parse the operand segments.
+ if (parseMmaSpOperand("A", frags[0]).failed())
+ return failure();
+ if (parseMmaSpOperand("B", frags[1]).failed(...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/168686