[Mlir-commits] [mlir] [MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (PR #168686)
Kirill Vedernikov
llvmlistbot at llvm.org
Thu Nov 20 00:35:00 PST 2025
================
@@ -2772,6 +2832,221 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
}
+/// Generate the enum value (as a C++ expression string) of the mma.sp.sync
+/// intrinsic for the given metadata, kind, satfinite flag and A/B/C/D fragments.
+/// The resulting id has the form
+///   llvm::Intrinsic::nvvm_mma_<Metadata>_<geom>_row_col[_<Kind>][_satfinite]<signature>
+/// where every "::" in Metadata and Kind is replaced by "_" (e.g.
+/// "sp::ordered_metadata" -> "sp_ordered_metadata"), <geom> is taken from A,
+/// and <signature> is MMA_SIGNATURE<A, B, C, D>.ret.
+class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ // The layout suffix is always "_row_col"; the Kind suffix is only appended
+ // when Kind is non-empty, and "_satfinite" only when Satfinite is set.
+ string id = "llvm::Intrinsic::nvvm_mma"
+ # "_" # !subst("::", "_", Metadata)
+ # "_" # A.geom
+ # "_row_col"
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+ # !if(Satfinite, "_satfinite", "")
+ # signature;
+}
+
+// Returns true if this combination of fragments (geometry and A/B/C/D element
+// types), metadata, kind and satf for MMA.SP ops is supported; false otherwise.
+// frags is indexed as [A, B, C, D]; the geometry is taken from frags[0].
+// The !cond below yields the value of the first clause whose condition holds,
+// so each clause names one unsupported combination and the final `true`
+// clause accepts everything else.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+ string kind, int satf> {
+ // There is no layout parameter here: the intrinsic names built by
+ // MMA_SP_SYNC_NAME always use the "_row_col" layout suffix.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ // True when A has an integer element type (s8, u8, s4 or u4).
+ bit is_int = !or(!eq(a_type, "s8"),
+ !eq(a_type, "u8"),
+ !eq(a_type, "s4"),
+ !eq(a_type, "u4"));
+
+ bit ret = !cond(
+
+ // Limit satf to valid types: satfinite is only accepted for integer A types.
+ !and(!eq(satf, 1),
+ !eq(is_int, 0)): false,
+
+ // f16/bf16/tf32 require A and B to be the same type.
+ !and(!or(!eq(a_type, "f16"),
+ !eq(a_type, "bf16"),
+ !eq(a_type, "tf32")),
+ !ne(a_type, b_type)): false,
+
+ // m16n8k16, m16n8k32 and m16n8k64 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32"),
+ !eq(geom, "m16n8k64")),
+ !ne(c_type, d_type)): false,
+
+ // The e3m2/e2m3/e2m1 element types (in A or B) are only accepted when a
+ // kind (e.g. "kind::f8f6f4") is given.
+ !and(!eq(kind, ""),
+ !or(!eq(a_type, "e3m2"),
+ !eq(a_type, "e2m3"),
+ !eq(a_type, "e2m1"),
+ !eq(b_type, "e3m2"),
+ !eq(b_type, "e2m3"),
+ !eq(b_type, "e2m1"))): false,
+
+ // Without a kind, m16n8k64 does not accept f16 for C or D.
+ !and(!eq(kind, ""),
+ !eq(geom, "m16n8k64"),
+ !or(!eq(c_type, "f16"),
+ !eq(d_type, "f16"))): false,
+
+ // A kind is only accepted together with metadata other than plain "sp"
+ // (i.e. "sp::ordered_metadata"), the m16n8k64 geometry, and a non-integer
+ // A type.
+ !and(!ne(kind, ""),
+ !or(!eq(metadata, "sp"),
+ !ne(geom, "m16n8k64"),
+ !eq(is_int, 1))): false,
+
+ // All other combinations are supported.
+ true: true
+ );
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.sync
+/// intrinsic enum value.
+class MMA_SP_SYNC_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
+ !foreach(metadata, ["sp", "sp::ordered_metadata"],
+ !foreach(kind, ["", "kind::f8f6f4"],
+ !foreach (satf, [0, 1],
+ !if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
+ # " && (satf.has_value() ? " # satf # " == static_cast<int>(*satf) : true)"
+ # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata")
+ # " && " # !if(!eq(kind, ""), "!hasKind", "hasKind") # ")\n"
----------------
kvederni wrote:
It has been fixed in 156919a.
https://github.com/llvm/llvm-project/pull/168686
More information about the Mlir-commits
mailing list