[Mlir-commits] [mlir] a7c8505 - [MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (#168686)
llvmlistbot at llvm.org
Fri Nov 28 02:03:37 PST 2025
Author: Kirill Vedernikov
Date: 2025-11-28T15:33:32+05:30
New Revision: a7c85052ebe7813da50cd461fdccccacb296017a
URL: https://github.com/llvm/llvm-project/commit/a7c85052ebe7813da50cd461fdccccacb296017a
DIFF: https://github.com/llvm/llvm-project/commit/a7c85052ebe7813da50cd461fdccccacb296017a.diff
LOG: [MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (#168686)
This change adds sparse MMA intrinsics to the MLIR -> NVVM IR -> NVPTX
flow. The NVVM and NVPTX implementations are based on PTX ISA 9.0.
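
For instance (mirroring the added tests), a 2:4-sparse f16 MMA at
m16n8k16 round-trips through the dialect as:

  %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
         sparseMetadata[%meta] selector[%sel]
         {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
         : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>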
Added:
mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d78145d690fc8..c427b0942aafd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2131,6 +2131,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
/// Generate the signature part of the mma intrinsic name.
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
+    // FP8/F8F6F4 ops are identified by A, B inputs & accumulator & result type.
+ !or(!eq(A.ptx_elt_type, "e4m3"),
+ !eq(A.ptx_elt_type, "e5m2"),
+ !eq(A.ptx_elt_type, "e3m2"),
+ !eq(A.ptx_elt_type, "e2m3"),
+ !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
@@ -2257,6 +2263,31 @@ class NVVM_MMA_OPS {
list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+
+ list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,16>, GEOM<16,8,32>],
+ ["bf16"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,8>, GEOM<16,8,16>],
+ ["tf32"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,16>, GEOM<16,8,32>],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+ list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"], ["f16", "f32"]>.ret;
+ list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,64>, GEOM<16,8,128>],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,32>, GEOM<16,8,64>],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat(
+ bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
+ subint_mma_sp_ops, int_mma_sp_ops);
+
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -2362,6 +2393,16 @@ def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
let assemblyFormat = "`<` $value `>`";
}
+/// MMA kind types (for mixed-precision FP8/FP6/FP4 operations)
+def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>;
+def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind",
+ [MMAKindF8F6F4]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> {
+ let assemblyFormat = "`<` $value `>`";
+}
/// Attribute to hold the MMA shape
def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
@@ -2506,12 +2547,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
+def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>;
+def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>;
+def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>;
+def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>;
+def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>;
def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
[MMATypeF16, MMATypeF32, MMATypeTF32,
MMATypeBF16, MMATypeS8, MMATypeU8,
MMATypeS32, MMATypeS4, MMATypeU4,
- MMATypeB1, MMATypeF64]> {
+ MMATypeB1, MMATypeF64,
+ MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
@@ -2948,6 +2995,216 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
}
+/// Generate the intrinsic enum value for an mma.sp.sync operation.
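+/// For example, Metadata = "sp::ordered_metadata" with f16 fragments and the
+/// m16n8k16 geometry yields (assuming MMA_SIGNATURE appends one
+/// "_<ptx_elt_type>" suffix per identifying fragment):
+///   llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_m16n8k16_row_col_f16_f16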
+class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string id = "llvm::Intrinsic::nvvm_mma"
+ # "_" # !subst("::", "_", Metadata)
+ # "_" # A.geom
+ # "_row_col"
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+ # !if(Satfinite, "_satfinite", "")
+ # signature;
+}
+
+// Returns true if this combination of fragment types/metadata/kind/satf for MMA.SP ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+ string kind, int satf> {
+  // MMA.SP ops always use the row.col layout, so only types and geometry need checking.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit is_int = !or(!eq(a_type, "s8"),
+ !eq(a_type, "u8"),
+ !eq(a_type, "s4"),
+ !eq(a_type, "u4"));
+
+ bit ret = !cond(
+
+ // Limit satf to valid types
+ !and(!eq(satf, 1),
+ !eq(is_int, 0)): false,
+
+    // f16/bf16/tf32 require A and B to be the same type.
+ !and(!or(!eq(a_type, "f16"),
+ !eq(a_type, "bf16"),
+ !eq(a_type, "tf32")),
+ !ne(a_type, b_type)): false,
+
+    // m16n8k16, m16n8k32, and m16n8k64 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32"),
+ !eq(geom, "m16n8k64")),
+ !ne(c_type, d_type)): false,
+
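+    // FP6/FP4 element types (e3m2/e2m3/e2m1) are only valid with a kind qualifier.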
+ !and(!eq(kind, ""),
+ !or(!eq(a_type, "e3m2"),
+ !eq(a_type, "e2m3"),
+ !eq(a_type, "e2m1"),
+ !eq(b_type, "e3m2"),
+ !eq(b_type, "e2m3"),
+ !eq(b_type, "e2m1"))): false,
+
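+    // Without a kind qualifier, m16n8k64 floating-point ops only accept an f32 accumulator.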
+ !and(!eq(kind, ""),
+ !eq(geom, "m16n8k64"),
+ !or(!eq(c_type, "f16"),
+ !eq(d_type, "f16"))): false,
+
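+    // A kind qualifier requires ordered metadata, the m16n8k64 geometry, and non-integer types.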
+ !and(!ne(kind, ""),
+ !or(!eq(metadata, "sp"),
+ !ne(geom, "m16n8k64"),
+ !eq(is_int, 1))): false,
+
+    // All others are OK.
+ true: true
+ );
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.sync
+/// intrinsic enum value.
+class MMA_SP_SYNC_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
+ !foreach(metadata, ["sp", "sp::ordered_metadata"],
+ !foreach(kind, ["", "kind::f8f6f4"],
+ !foreach (satf, [0, 1],
+ !if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
+ # " && (satf.has_value() ? " # satf # " == static_cast<int>(*satf) : true)"
+ # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata") # ")\n"
+ # " return " #
+ MMA_SP_SYNC_NAME<metadata, kind, satf, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // satf
+ ) // kind
+ ) // metadata
+ ); // all_mma_sp_sync_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
+
+ let summary = "cooperative sparse matrix-multiply and accumulate";
+
+ let description = [{
+ The `nvvm.mma.sp.sync` operation collectively performs the sparse operation
+ `D = matmul(A_sparse, B) + C` using all threads in a warp.
+
+ This operation is similar to `nvvm.mma.sync` but with structured sparsity
+ in the A operand. The sparsity follows the 2:4 structured sparse pattern
+ where 2 out of every 4 elements are non-zero.
+
+ All the threads in the warp must execute the same `mma.sp.sync` operation.
+
+ The `sparseMetadata` operand provides the sparsity indices that indicate
+ which elements in the A operand are non-zero. The `sparsitySelector`
+ controls how the indices are distributed among threads in the warp and
+ should typically be 0 or 1.
+
+ The optional `orderedMetadata` attribute specifies the metadata ordering:
+ - Absence (default): Uses standard sparse metadata ordering
+ - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
+
+    The optional `kind` attribute selects mixed-precision modes for the FP8/FP6/FP4 family:
+    - `f8f6f4`: Enables the FP6 types e3m2/e2m3, the FP4 type e2m1, and an f16 accumulator (PTX ISA 8.7+, sm_90+)
+ - Only valid with ordered metadata and m16n8k64 shape
+
+ The shapes, layouts, and data types follow the same constraints as the
+ regular `nvvm.mma.sync` operation, but the A operand contains only the
+ non-zero elements in compressed format.
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+             {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+
+ // With ordered metadata:
+ %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+             {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ ```
+ }];
+
+ let results = (outs LLVM_AnyStruct:$res);
+ let arguments = (ins NVVM_MMAShapeAttr:$shape,
+ OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
+ OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+ OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+ UnitAttr:$orderedMetadata,
+ OptionalAttr<MMAKindAttr>:$kind,
+ Variadic<LLVM_Type>:$operandA,
+ Variadic<LLVM_Type>:$operandB,
+ Variadic<LLVM_Type>:$operandC,
+ I32:$sparseMetadata,
+ I32:$sparsitySelector);
+
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+        int64_t m, int64_t n, int64_t k,
+ std::optional<MMAIntOverflow> satf,
+ bool orderedMetadata,
+ std::optional<MMAKind> kind,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+ llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
+ }],
+ MMA_SP_SYNC_INTR<>.id, [{
+ return 0;
+ }
+
+ static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
+ bool isAccumulator);
+
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$sparseMetadata, "Value":$sparsitySelector,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<MMAIntOverflow>":$intOverflow,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 428bc72c88a30..846b299312d67 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1026,6 +1026,482 @@ LogicalResult MmaOp::verify() {
return success();
}
+MMATypes MmaSpOp::accumPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
+ getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
+ assert(val.has_value() && "accumulator PTX type should always be inferrable");
+ return val.value();
+}
+
+MMATypes MmaSpOp::resultPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val =
+ MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
+ assert(val.has_value() && "result PTX type should always be inferrable");
+ return val.value();
+}
+
+mlir::NVVM::IDArgPair
+MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MmaSpOp>(op);
+
+ // Get operands
+ llvm::SmallVector<llvm::Value *> args;
+ for (mlir::Value v : thisOp.getOperands())
+ args.push_back(mt.lookupValue(v));
+
+ // Get intrinsic ID using the existing getIntrinsicID method
+ auto intId = MmaSpOp::getIntrinsicID(
+ thisOp.getShape().getM(), thisOp.getShape().getN(),
+ thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
+ thisOp.getOrderedMetadata(), thisOp.getKind(),
+ *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
+ thisOp.accumPtxType(), thisOp.resultPtxType());
+
+ return {intId, args};
+}
+
+void MmaSpOp::print(OpAsmPrinter &p) {
+ SmallVector<Type, 4> regTypes;
+ struct OperandFragment {
+ StringRef operandName;
+ StringRef ptxTypeAttr;
+ SmallVector<Value, 4> regs;
+ explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ : operandName(name), ptxTypeAttr(ptxTypeName) {}
+ };
+
+ std::array<OperandFragment, 5> frags{
+ OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ OperandFragment("C", ""), OperandFragment("sparseMetadata", ""),
+ OperandFragment("selector", "")};
+ SmallVector<StringRef, 4> ignoreAttrNames{
+ mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
+
+ // Handle variadic operands A, B, C
+ for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
+ auto &frag = frags[fragIdx];
+ auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+ for (auto operandIdx = varOperandSpec.first;
+ operandIdx < varOperandSpec.first + varOperandSpec.second;
+ operandIdx++) {
+ frag.regs.push_back(this->getOperand(operandIdx));
+ if (operandIdx == varOperandSpec.first) {
+ regTypes.push_back(this->getOperand(operandIdx).getType());
+ }
+ }
+ std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+ regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ if (inferredType)
+ ignoreAttrNames.push_back(frag.ptxTypeAttr);
+ }
+
+ // Handle sparse metadata and selector (single operands)
+ frags[3].regs.push_back(getSparseMetadata());
+ frags[4].regs.push_back(getSparsitySelector());
+
+ auto printMmaSpOperand = [&](const OperandFragment &frag) -> void {
+ p << " " << frag.operandName;
+ p << "[";
+ p.printOperands(frag.regs);
+ p << "]";
+ };
+
+ for (const auto &frag : frags)
+ printMmaSpOperand(frag);
+
+ p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
+ p << " : ";
+ p << "(";
+ for (int i = 0; i < 3; ++i) {
+ p << regTypes[i];
+ if (i < 2)
+ p << ", ";
+ }
+ p << ") -> " << getResult().getType();
+}
+
+void MmaSpOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
+ std::optional<MMAIntOverflow> intOverflow,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+
+ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+ MLIRContext *ctx = builder.getContext();
+ result.addAttribute(
+ "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+ result.addOperands(operandA);
+ result.addOperands(operandB);
+ result.addOperands(operandC);
+ result.addOperands(sparseMetadata);
+ result.addOperands(sparsitySelector);
+
+ if (multiplicandPtxTypes) {
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ } else {
+ if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+ if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+ }
+
+ if (intOverflow.has_value())
+ result.addAttribute("intOverflowBehavior",
+ MMAIntOverflowAttr::get(ctx, *intOverflow));
+
+ result.addTypes(resultType);
+ result.addAttribute(
+ MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()), 1,
+ 1})); // sparseMetadata and sparsitySelector
+}
+
+ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
+ struct OperandFragment {
+ std::optional<MMATypes> elemtype;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+ SmallVector<Type> regTypes;
+ };
+
+ Builder &builder = parser.getBuilder();
+  std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector, result
+
+ NamedAttrList namedAttributes;
+
+ // A helper to parse the operand segments.
+ auto parseMmaSpOperand = [&](StringRef operandName,
+ OperandFragment &frag) -> LogicalResult {
+ if (parser.parseKeyword(operandName).failed())
+ return failure();
+ if (parser
+ .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
+ return failure();
+ return success();
+ };
+
+ // Parse the operand segments.
+ if (parseMmaSpOperand("A", frags[0]).failed())
+ return failure();
+ if (parseMmaSpOperand("B", frags[1]).failed())
+ return failure();
+ if (parseMmaSpOperand("C", frags[2]).failed())
+ return failure();
+ if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
+ return failure();
+ if (parseMmaSpOperand("selector", frags[4]).failed())
+ return failure();
+
+ if (parser.parseOptionalAttrDict(namedAttributes).failed())
+ return failure();
+
+ // Parse the type specification and resolve operands.
+ SmallVector<Type, 3> operandTypes;
+ if (failed(parser.parseColon()))
+ return failure();
+ if (failed(parser.parseLParen()))
+ return failure();
+ if (failed(parser.parseTypeList(operandTypes)))
+ return failure();
+ if (failed(parser.parseRParen()))
+ return failure();
+ if (operandTypes.size() != 3)
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expected one type for each operand segment but got " +
+ Twine(operandTypes.size()) + " types");
+ for (const auto &iter : llvm::enumerate(operandTypes)) {
+ auto &frag = frags[iter.index()];
+ frag.regTypes.resize(frag.regs.size(), iter.value());
+ if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+ parser.getNameLoc(), result.operands)))
+ return failure();
+ frag.elemtype =
+ MmaOp::inferOperandMMAType(frag.regTypes[0],
+ /*isAccumulator*/ iter.index() >= 2);
+ }
+
+ Type resultType;
+ if (parser.parseArrow() || parser.parseType(resultType))
+ return failure();
+ frags[5].elemtype =
+ MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
+
+ // Resolve sparse metadata and selector (assume i32 type)
+ Type i32Type = builder.getIntegerType(32);
+ if (parser
+ .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+ if (parser
+ .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+
+ std::array<StringRef, 2> names{"multiplicandAPtxType",
+ "multiplicandBPtxType"};
+ for (unsigned idx = 0; idx < names.size(); idx++) {
+ const auto &frag = frags[idx];
+ std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+ if (!frag.elemtype.has_value() && !attr.has_value()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "attribute " + names[idx] +
+ " is not provided explicitly and cannot be inferred");
+ }
+ if (!attr.has_value())
+ result.addAttribute(
+ names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+ }
+
+ result.addTypes(resultType);
+ if (!namedAttributes.empty())
+ result.addAttributes(namedAttributes);
+ result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // sparseMetadata
+ 1 // sparsitySelector
+ }));
+ return success();
+}
+
+LogicalResult MmaSpOp::verify() {
+ MLIRContext *context = getContext();
+ auto f16Ty = Float16Type::get(context);
+ auto i32Ty = IntegerType::get(context, 32);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ auto f32Ty = Float32Type::get(context);
+ auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+ auto s32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+ auto f32x8StructTy =
+ LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+ auto f16x2x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+ auto f32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+ auto s32x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+ std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
+ getShapeAttr().getK()};
+
+ // These variables define the set of allowed data types for matrices A, B, C,
+ // and result.
+ using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+ using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+ AllowedShapes allowedShapes;
+ AllowedTypes expectedA;
+ AllowedTypes expectedB;
+ AllowedTypes expectedC;
+ SmallVector<Type> expectedResult;
+
+ // When M = 16, we just need to calculate the number of 8xk tiles, where
+ // k is a factor that depends on the data type.
+ if (mmaShape[0] == 16) {
+ int64_t kFactor;
+ Type multiplicandFragType;
+ switch (*getMultiplicandAPtxType()) {
+ case MMATypes::tf32:
+ kFactor = 4;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
+ allowedShapes.push_back({16, 8, 8});
+ allowedShapes.push_back({16, 8, 16});
+ break;
+ case MMATypes::bf16:
+ kFactor = 8;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::f16:
+ kFactor = 8;
+ multiplicandFragType = f16x2Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k16 and m16n8k32 for f16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::s4:
+ case MMATypes::u4:
+ kFactor = 32;
+ // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
+ allowedShapes.push_back({16, 8, 64});
+ allowedShapes.push_back({16, 8, 128});
+ break;
+ case MMATypes::s8:
+ case MMATypes::u8:
+ kFactor = 16;
+ // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
+ allowedShapes.push_back({16, 8, 32});
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ case MMATypes::e4m3:
+ case MMATypes::e5m2:
+ case MMATypes::e3m2:
+ case MMATypes::e2m3:
+ case MMATypes::e2m1:
+ kFactor = 32;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k64 for FP8 types
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ default:
+ return emitError("invalid shape or multiplicand type: " +
+ stringifyEnum(getMultiplicandAPtxType().value()));
+ }
+
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedResult.push_back(s32x4StructTy);
+ expectedC.emplace_back(4, i32Ty);
+ multiplicandFragType = i32Ty;
+ } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
+ *getMultiplicandAPtxType() <= MMATypes::e2m1) {
+ // FP8 types
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ } else {
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ }
+
+ // For sparse MMA, A operand is compressed (2:4 sparsity means half the
+ // elements)
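+  // E.g. for f16 at m16n8k32 (kFactor = 8): unitA = (16/8) * (32/8) / 2 = 4 and
+  // unitB = (8/8) * (32/8) = 4 vector<2xf16> registers, matching the A[4]/B[4]
+  // operand counts in the tests.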
+ int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
+ int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+ expectedA.emplace_back(unitA, multiplicandFragType);
+ expectedB.emplace_back(unitB, multiplicandFragType);
+
+ if (resultPtxType() != accumPtxType())
+ return emitOpError("ctype does not match dtype");
+ }
+
+ // In the M=8 case, there is only 1 possible case per data type.
+ if (mmaShape[0] == 8) {
+ if (*getMultiplicandAPtxType() == MMATypes::f16) {
+ expectedA.emplace_back(2, f16x2Ty);
+ expectedB.emplace_back(2, f16x2Ty);
+ expectedResult.push_back(f16x2x4StructTy);
+ expectedResult.push_back(f32x8StructTy);
+ expectedC.emplace_back(4, f16x2Ty);
+ expectedC.emplace_back(8, f32Ty);
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (*getMultiplicandAPtxType() == MMATypes::f64) {
+ Type f64Ty = Float64Type::get(context);
+ expectedA.emplace_back(1, f64Ty);
+ expectedB.emplace_back(1, f64Ty);
+ expectedC.emplace_back(2, f64Ty);
+ expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(2, f64Ty)));
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedA.push_back({i32Ty});
+ expectedB.push_back({i32Ty});
+ expectedC.push_back({i32Ty, i32Ty});
+ expectedResult.push_back(s32x2StructTy);
+ if (isInt4PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 32});
+ if (isInt8PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 16});
+ }
+ }
+
+ std::string errorMessage;
+ llvm::raw_string_ostream errorStream(errorMessage);
+
+ // Check that we matched an existing shape/dtype combination.
+ if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+ !llvm::is_contained(allowedShapes, mmaShape)) {
+ errorStream << "unimplemented variant for MMA shape <";
+ llvm::interleaveComma(mmaShape, errorStream);
+ errorStream << ">";
+ return emitOpError(errorMessage);
+ }
+
+ // Verify the operand types for segments of A, B, and C operands.
+ std::array<StringRef, 3> operandNames{"A", "B", "C"};
+ for (const auto &iter : llvm::enumerate(
+ SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+ auto spec = this->getODSOperandIndexAndLength(iter.index());
+ SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+ operand_type_begin() + spec.first +
+ spec.second);
+ bool match = llvm::is_contained(iter.value(), operandTySeg);
+
+ if (!match) {
+ errorStream << "Could not match types for the "
+ << operandNames[iter.index()]
+ << " operands; expected one of ";
+ for (const auto &x : iter.value()) {
+ errorStream << x.size() << "x" << x[0] << " ";
+ }
+ errorStream << "but got ";
+ llvm::interleaveComma(operandTySeg, errorStream);
+ return emitOpError(errorMessage);
+ }
+ }
+
+ // Check the result type
+ if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+ return expectedResultType == getResult().getType();
+ })) {
+ errorStream
+ << "Could not match allowed types for the result; expected one of ";
+ llvm::interleaveComma(expectedResult, errorStream);
+ errorStream << " but got " << getResult().getType();
+ return emitOpError(errorMessage);
+ }
+
+ // Ensure int4/int8 MMA variants specify the accum overflow behavior
+ // attribute.
+ if (isInt4PtxType(*getMultiplicandAPtxType()) ||
+ isInt8PtxType(*getMultiplicandAPtxType())) {
+ if (!getIntOverflowBehavior())
+ return emitOpError("op requires " +
+ getIntOverflowBehaviorAttrName().strref() +
+ " attribute");
+ }
+
+ // Validate sparse metadata type (should be i32)
+ if (!getSparseMetadata().getType().isInteger(32)) {
+ return emitOpError() << "sparse metadata must be i32 type";
+ }
+
+ // Validate sparsity selector type (should be i32)
+ if (!getSparsitySelector().getType().isInteger(32)) {
+ return emitOpError() << "sparsity selector must be i32 type";
+ }
+
+ return success();
+}
+
LogicalResult ShflOp::verify() {
auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
new file mode 100644
index 0000000000000..ff3e91b89016d
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
@@ -0,0 +1,221 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for sparse MMA (mma.sp.sync) operations with KIND variants.
+// The kind::f8f6f4 variant was introduced in PTX ISA 8.7 for sm_90+ architectures.
+//
+// Based on PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma
+//
+// KIND::F8F6F4 enables:
+// - Additional narrow floating-point types: e3m2 and e2m3 (FP6), e2m1 (FP4)
+// - F16 accumulator for m16n8k64 FP8 operations
+// - Mixed-precision FP8/FP6/FP4 computations
+//
+// Requirements:
+// - ONLY works with ordered metadata (sp::ordered_metadata)
+// - ONLY for shape m16n8k64
+// - ONLY for FP8 types (not integers or other floats)
+
+// =============================================================================
+// FP8 e4m3 Sparse MMA with KIND (m16n8k64)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP8 e5m2 Sparse MMA with KIND (m16n8k64)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP6 e3m2 Sparse MMA with KIND (m16n8k64)
+// NOTE: e3m2 is ONLY available with kind::f8f6f4
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP6 e2m3 Sparse MMA with KIND (m16n8k64)
+// NOTE: e2m3 is ONLY available with kind::f8f6f4
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP4 e2m1 Sparse MMA with KIND (m16n8k64)
+// NOTE: e2m1 is ONLY available with kind::f8f6f4
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f16
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f32
+func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>,
+ orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
new file mode 100644
index 0000000000000..a4e2812e54c12
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
@@ -0,0 +1,411 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for sparse MMA (mma.sp.sync) operations with ORDERED metadata.
+// The ordered metadata variant was introduced in PTX ISA 8.5 for sm_90+ architectures.
+//
+// Based on PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma
+//
+// Ordered metadata provides an alternative metadata ordering for 2:4 structured sparsity
+// that can offer better performance on newer architectures.
+
+// =============================================================================
+// F16 Sparse MMA Operations with Ordered Metadata (m16n8k16)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f16
+func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// F16 Sparse MMA Operations with Ordered Metadata (m16n8k32)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f16
+func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// BF16 Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_bf16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_bf16_f32
+func.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// TF32 Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k8_tf32_f32
+func.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_tf32_f32
+func.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// Integer (s8) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite
+func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Integer (u8) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_u8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u8_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (s4) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_s4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (u4) Sparse MMA Operations with Ordered Metadata
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_u4_s32
+func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e4m3) Sparse MMA Operations with Ordered Metadata
+// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16
+func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32
+func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e5m2) Sparse MMA Operations with Ordered Metadata
+// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16
+func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32
+func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
new file mode 100644
index 0000000000000..e7122aac61baf
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
@@ -0,0 +1,390 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for all sparse MMA (mma.sp.sync) operations in the
+// NVVM dialect, based on the PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma
+//
+// Sparse MMA operations follow 2:4 structured sparsity: at most 2 out of every
+// 4 consecutive elements in the A operand are non-zero. The A operand is
+// provided in compressed form (non-zero elements only), sparseMetadata encodes
+// the positions of those elements within each group, and selector indicates
+// which threads contribute the metadata.
+//
+// NOTE: These tests use the default (standard) metadata ordering.
+// For ordered metadata tests (PTX ISA 8.5+, sm_90+), see nvvm-mma-sp-ordered.mlir.
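+//
+// Illustrative sketch (editorial addition, not part of the checked-in test):
+// with f16 and 2:4 sparsity, a dense 4-element group such as {x, 0, 0, y}
+// (non-zeros at positions 0 and 3) is stored in compressed A as {x, y}, and
+// the corresponding metadata chunk packs the two 2-bit position indices
+// (0 and 3; exact bit order per the PTX ISA). This is why the sparse m16n8k32
+// f16 tests below pass only four vector<2xf16> A fragments: compressed A holds
+// half of the dense k = 32 extent, the same register footprint as a dense
+// m16n8k16 MMA. Syntactically, the ordered-metadata variants in
+// nvvm-mma-sp-ordered.mlir differ from these tests only by the orderedMetadata
+// unit attribute.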
+
+// =============================================================================
+// F16 Sparse MMA Operations (m16n8k16)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f16
+func.func @nvvm_mma_sp_m16n8k16_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f32
+func.func @nvvm_mma_sp_m16n8k16_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// F16 Sparse MMA Operations (m16n8k32)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f16
+func.func @nvvm_mma_sp_m16n8k32_f16_f16(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f32
+func.func @nvvm_mma_sp_m16n8k32_f16_f32(
+ %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// BF16 Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_bf16_f32
+func.func @nvvm_mma_sp_m16n8k16_bf16_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_bf16_f32
+func.func @nvvm_mma_sp_m16n8k32_bf16_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// TF32 Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k8_tf32_f32
+func.func @nvvm_mma_sp_m16n8k8_tf32_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_tf32_f32
+func.func @nvvm_mma_sp_m16n8k16_tf32_f32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// Integer (s8) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32
+func.func @nvvm_mma_sp_m16n8k32_s8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32_satfinite
+func.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s8_s32
+func.func @nvvm_mma_sp_m16n8k64_s8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Integer (u8) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_u8_s32
+func.func @nvvm_mma_sp_m16n8k32_u8_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u8_s32
+func.func @nvvm_mma_sp_m16n8k64_u8_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (s4) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s4_s32
+func.func @nvvm_mma_sp_m16n8k64_s4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_s4_s32
+func.func @nvvm_mma_sp_m16n8k128_s4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// Sub-byte Integer (u4) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u4_s32
+func.func @nvvm_mma_sp_m16n8k64_u4_s32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_u4_s32
+func.func @nvvm_mma_sp_m16n8k128_u4_s32(
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e4m3) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f16
+func.func @nvvm_mma_sp_m16n8k64_e4m3_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f32
+func.func @nvvm_mma_sp_m16n8k64_e4m3_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+
+// =============================================================================
+// FP8 (e5m2) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f16
+func.func @nvvm_mma_sp_m16n8k64_e5m2_f16(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f32
+func.func @nvvm_mma_sp_m16n8k64_e5m2_f32(
+ %a0 : i32, %a1 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+ %meta : i32, %sel : i32) {
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return
+}
+